diff --git a/.gitignore b/.gitignore index 5afe375f46f07b3b557ae23f75740b337517d3bd..1ef4c297ee4f369775c13b32a46a55887de719e7 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ __pycache__ *.swp .vscode/ cmake_build/ +tensorflow/contrib/cmake/_build/ .idea/** /build/ [Bb]uild/ @@ -30,6 +31,7 @@ Podfile.lock xcuserdata/** /api_init_files_list.txt /estimator_api_init_files_list.txt +*.whl # Android .gradle diff --git a/CODEOWNERS b/CODEOWNERS index b9f0313cc6d59d3fbdcd014e1a528126d863075a..78f80c8d718983f00fd5010c3fe5d561124d3714 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,53 +1,64 @@ -# NOTE: Disabled temporarily because it's too noisy on pushes. # Where component owners are known, add them here. -# /tensorflow/core/platform/windows/ @mrry -# /tensorflow/java/ @asimshankar -# /tensorflow/tensorboard/ @jart @dandelionmane -# /tensorflow/tools/docs/ @markdaoust +/tensorflow/core/debug @caisq +/tensorflow/core/platform/windows/ @mrry +/tensorflow/go @asimshankar +/tensorflow/java/ @asimshankar +/tensorflow/python/debug @caisq +/tensorflow/python/tools/api/generator/ @annarev +/tensorflow/tensorboard/ @jart +/tensorflow/tools/docs/ @markdaoust # contrib -# NEED OWNER: /tensorflow/contrib/avro/ -# /tensorflow/contrib/batching/ @alextp @chrisolston -# /tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon -# /tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva -# /tensorflow/contrib/cmake/ @mrry @benoitsteiner -# /tensorflow/contrib/copy_graph/ @tucker @poxvoculi -# /tensorflow/contrib/crf/ @kentonl -# /tensorflow/contrib/data/ @mrry -# /tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi -# /tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo -# /tensorflow/contrib/ffmpeg/ @fredbertsch -# NEED OWNER: /tensorflow/contrib/framework/ -# /tensorflow/contrib/graph_editor/ @purpledog +# NEED OWNER: /tensorflow/contrib/all_reduce +/tensorflow/contrib/batching/ @alextp @chrisolston +/tensorflow/contrib/bayesflow/ @ebrevdo 
@rsepassi @jvdillon +/tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva +/tensorflow/contrib/checkpoint/ @allenlavoie +/tensorflow/contrib/cluster_resolver/ @frankchn +/tensorflow/contrib/cmake/ @mrry +/tensorflow/contrib/copy_graph/ @tucker @poxvoculi +/tensorflow/contrib/crf/ @kentonl +/tensorflow/contrib/data/ @mrry +/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn +/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi +/tensorflow/contrib/eager @alextp @asimshankar +/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo +/tensorflow/contrib/ffmpeg/ @fredbertsch +/tensorflow/contrib/framework/ @ebrevdo +/tensorflow/contrib/gan/ @joel-shor +/tensorflow/contrib/graph_editor/ @purpledog # NEED OWNER: /tensorflow/contrib/grid_rnn/ -# /tensorflow/contrib/hvx/ @satok16 -# /tensorflow/contrib/integrate/ @shoyer -# /tensorflow/contrib/kernel_methods/ @petrosmol -# /tensorflow/contrib/ios_examples/ @petewarden -# /tensorflow/contrib/labeled_tensor/ @shoyer -# /tensorflow/contrib/layers/ @fchollet @martinwicke -# /tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp -# /tensorflow/contrib/linalg/ @langmore -# /tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis -# /tensorflow/contrib/lookup/ @ysuematsu @andreasst -# /tensorflow/contrib/losses/ @alextp @ispirmustafa -# /tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg -# /tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa -# /tensorflow/contrib/nccl/ @cwhipkey @zheng-xq -# /tensorflow/contrib/opt/ @strategist333 -# /tensorflow/contrib/pi_examples/ @maciekcc -# /tensorflow/contrib/quantization/ @petewarden @cwhipkey @keveman -# /tensorflow/contrib/rnn/ @ebrevdo -# /tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh -# /tensorflow/contrib/seq2seq/ @lukaszkaiser -# /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh -# /tensorflow/contrib/slim/ @sguada @thenbasilmanran -# 
/tensorflow/contrib/stateless/ @girving -# /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank -# /tensorflow/contrib/testing/ @dandelionmane -# /tensorflow/contrib/timeseries/ @allenlavoie -# /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu -# /tensorflow/contrib/training/ @joel-shor @ebrevdo -# /tensorflow/contrib/util/ @sherrym +/tensorflow/contrib/hvx/ @satok16 +/tensorflow/contrib/integrate/ @shoyer +/tensorflow/contrib/kernel_methods/ @petrosmol +/tensorflow/contrib/ios_examples/ @petewarden +/tensorflow/contrib/labeled_tensor/ @shoyer +/tensorflow/contrib/layers/ @fchollet @martinwicke +/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp +/tensorflow/contrib/linalg/ @langmore +/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis +/tensorflow/contrib/lookup/ @ysuematsu @andreasst +/tensorflow/contrib/losses/ @alextp @ispirmustafa +/tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg +/tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa +/tensorflow/contrib/nccl/ @cwhipkey @zheng-xq +/tensorflow/contrib/opt/ @strategist333 @alextp +/tensorflow/contrib/pi_examples/ @maciekcc +/tensorflow/contrib/quantization/ @petewarden +/tensorflow/contrib/rnn/ @ebrevdo @scottzhu +/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenl +/tensorflow/contrib/seq2seq/ @ebrevdo @lmthang +/tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh +/tensorflow/contrib/slim/ @sguada @thenbasilmanran +/tensorflow/contrib/stateless/ @girving @alextp +/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank +/tensorflow/contrib/tensorrt/ @aaroey +# NEED OWNER: /tensorflow/contrib/testing/ +/tensorflow/contrib/timeseries/ @allenlavoie +/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj +/tensorflow/contrib/training/ @joel-shor @ebrevdo +/tensorflow/contrib/util/ @sherrym + +/third_party/systemlibs/ @perfinion diff --git a/README.md b/README.md index 
669ff5b711c62455f48038743ca1e089fa23d9e6..e3092e551e32d7f01e9bebd65323d1b5691f0269 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,8 @@ The TensorFlow project strives to abide by generally accepted best practices in | **Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-cpu.html) | [pypi](https://pypi.org/project/tf-nightly/) | | **Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/windows-gpu.html) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | | **Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | +| **Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv6l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) | +| **Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py2.html) [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | 
[Py2](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp27-none-linux_armv7l.whl) [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) | ### Community Supported Builds @@ -100,16 +102,16 @@ The TensorFlow project strives to abide by generally accepted best practices in | **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | | **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | | **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | -| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6| ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)
[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)
[1.9.0 py3.6](https://storage.cloud.google.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) | +| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.10.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp27-cp27mu-linux_x86_64.whl)
[1.10.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl)
[1.10.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl) | ## For more information -* [Tensorflow Blog](https://medium.com/tensorflow) +* [TensorFlow Blog](https://medium.com/tensorflow) * [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) * [TensorFlow Model Zoo](https://github.com/tensorflow/models) * [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) * [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) -* [Tensorflow Twitter](https://twitter.com/tensorflow) +* [TensorFlow Twitter](https://twitter.com/tensorflow) * [TensorFlow Website](https://www.tensorflow.org) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) diff --git a/configure.py b/configure.py index bf570a9fa394f8fb7ef98f57007b656afd0c466c..361bd4764dc5c1900be7378f51c00aedf6f2ce41 100644 --- a/configure.py +++ b/configure.py @@ -45,7 +45,7 @@ _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine() _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' -_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15] +_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 @@ -848,7 +848,7 @@ def set_tf_cuda_version(environ_cp): cuda_toolkit_paths_full = [os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths] if any([os.path.exists(x) for x in cuda_toolkit_paths_full]): - break + break # Reset and retry print('Invalid path to CUDA %s toolkit. 
%s cannot be found' % @@ -1399,8 +1399,11 @@ def set_grpc_build_flags(): write_to_bazelrc('build --define grpc_no_ares=true') -def set_build_strip_flag(): - write_to_bazelrc('build --strip=always') +def set_system_libs_flag(environ_cp): + syslibs = environ_cp.get('TF_SYSTEM_LIBS', '') + syslibs = ','.join(sorted(syslibs.split(','))) + if syslibs and syslibs != '': + write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs) def set_windows_build_flags(environ_cp): @@ -1505,6 +1508,8 @@ def main(): False, 'gdr') set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support', False, 'verbs') + set_build_var(environ_cp, 'TF_NEED_NGRAPH', 'nGraph', + 'with_ngraph_support', False, 'ngraph') set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False) if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': @@ -1538,6 +1543,10 @@ def main(): if environ_cp.get('TF_DOWNLOAD_CLANG') != '1': # Set up which clang we should use as the cuda / host compiler. set_clang_cuda_compiler_path(environ_cp) + else: + # Use downloaded LLD for linking. 
+ write_to_bazelrc('build:cuda_clang --config=download_clang_use_lld') + write_to_bazelrc('test:cuda_clang --config=download_clang_use_lld') else: # Set up which gcc nvcc should use as the host compiler # No need to set this on Windows @@ -1559,7 +1568,7 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) - set_build_strip_flag() + set_system_libs_flag(environ_cp) if is_windows(): set_windows_build_flags(environ_cp) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f1000c1bffd07d090315f6ebbbd3ea504e710df9..b5e0a4e98b0c183454afa4a4389dcf73802b219b 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -23,6 +23,14 @@ load( "//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files", # @unused ) +load( + "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", + "TENSORFLOW_API_INIT_FILES_V1", # @unused +) +load( + "//third_party/ngraph:build_defs.bzl", + "if_ngraph", +) # Config setting used when building for products # which requires restricted licenses to be avoided. @@ -411,6 +419,14 @@ config_setting( visibility = ["//visibility:public"], ) +# This flag is set from the configure step when the user selects with nGraph option. 
+# By default it should be false +config_setting( + name = "with_ngraph_support", + values = {"define": "with_ngraph_support=true"}, + visibility = ["//visibility:public"], +) + package_group( name = "internal", packages = [ @@ -431,7 +447,7 @@ filegroup( name = "intel_binary_blob", data = if_mkl_ml( [ - "//third_party/intel_mkl_ml", + "//third_party/mkl:intel_binary_blob", ], ), ) @@ -563,7 +579,7 @@ tf_cc_shared_object( "//tensorflow/cc:scope", "//tensorflow/cc/profiler", "//tensorflow/core:tensorflow", - ], + ] + if_ngraph(["@ngraph_tf//:ngraph_tf"]), ) exports_files( @@ -577,6 +593,7 @@ gen_api_init_files( name = "tensorflow_python_api_gen", srcs = ["api_template.__init__.py"], api_version = 1, + output_files = TENSORFLOW_API_INIT_FILES_V1, root_init_template = "api_template.__init__.py", ) diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index 440e9f8dbd2f4b2a2ab78eaaf26408584e7c1446..21677512b63828fa2035527ed573bf4dc4603085 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -28,7 +28,8 @@ contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top -app.flags = flags # pylint: disable=undefined-variable +from tensorflow.python.platform import app # pylint: disable=g-import-not-at-top +app.flags = flags del absolute_import del division diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 8a9301d584775cff3ae315e6fd856b00d1734248..109b3b37aace34914e5307981ead597c25c7fb8f 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -117,6 +117,7 @@ tf_cuda_library( deps = [ ":c_api", ":c_api_internal", + "//tensorflow/c/eager:c_api", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/contrib/tpu:all_ops", "//tensorflow/core:core_cpu", @@ -127,6 +128,15 @@ tf_cuda_library( ], ) +cc_library( + name = "c_api_headers", + hdrs = [ + "c_api.h", + ], + copts = tf_copts(), + visibility = 
["//tensorflow:__subpackages__"], +) + exports_files( [ "version_script.lds", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 19ccb6e71d2f3021c1ce5c8905d8a72059c1cfcb..173bbea596a4276559f5cd67824e5cc75313985c 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -202,7 +202,8 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, buf->len_ = len; if (dtype != TF_STRING && dtype != TF_RESOURCE && tensorflow::DataTypeCanUseMemcpy(static_cast(dtype)) && - reinterpret_cast(data) % EIGEN_MAX_ALIGN_BYTES != 0) { + reinterpret_cast(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) != + 0) { // TF_STRING and TF_RESOURCE tensors have a different representation in // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste // (any alignment requirements will be taken care of by TF_TensorToTensor @@ -1239,7 +1240,7 @@ void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, const char* value, size_t length) { tensorflow::NameAttrList func_name; - func_name.set_name(std::string(value, value + length)); + func_name.set_name(string(value, value + length)); desc->node_builder.Attr(attr_name, func_name); } @@ -2064,7 +2065,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, for (int i = 0; i < size; ++i) { TensorId id = results.missing_unused_input_map_keys[i]; - tf_results->missing_unused_key_names_data.push_back(std::string(id.first)); + tf_results->missing_unused_key_names_data.emplace_back(id.first); tf_results->missing_unused_key_names[i] = tf_results->missing_unused_key_names_data.back().c_str(); tf_results->missing_unused_key_indexes[i] = id.second; diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 6617c5a572e90e78369f73d714f39942f213040f..09d482d6df45aa95a2f463f1c9601048bea24c04 100644 --- a/tensorflow/c/c_api_experimental.h +++ 
b/tensorflow/c/c_api_experimental.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" // -------------------------------------------------------------------------- // Experimental C API for TensorFlow. @@ -131,6 +132,9 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, TF_Tensor* tensor, TF_Status* status); +TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession( + const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index aa2a537f03be31ae45ff3d6f7815b449d661cf9c..03516c39dc970aa23967107d3a0446da94669465 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -259,8 +259,8 @@ TEST(CAPI, DeprecatedSession) { TF_Run(session, run_options, nullptr, nullptr, 0, nullptr, nullptr, 0, nullptr, 0, run_metadata, s); EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(std::string("Session was not created with a graph before Run()!"), - std::string(TF_Message(s))); + EXPECT_EQ("Session was not created with a graph before Run()!", + string(TF_Message(s))); TF_DeleteBuffer(run_metadata); TF_DeleteBuffer(run_options); @@ -1224,8 +1224,8 @@ class CApiColocationTest : public ::testing::Test { TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_); if (expected.empty()) { ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); - EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."), - std::string(TF_Message(s_))); + EXPECT_EQ("Operation 'add' has no attr named '_class'.", + string(TF_Message(s_))); return; } EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); @@ -1369,16 +1369,16 @@ TEST(CAPI, SavedModel) { input.flat()(i) = example.SerializeAsString(); } - const tensorflow::string input_op_name = - std::string(tensorflow::ParseTensorName(input_name).first); + 
const tensorflow::string input_op_name( + tensorflow::ParseTensorName(input_name).first); TF_Operation* input_op = TF_GraphOperationByName(graph, input_op_name.c_str()); ASSERT_TRUE(input_op != nullptr); csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}}); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - const tensorflow::string output_op_name = - std::string(tensorflow::ParseTensorName(output_name).first); + const tensorflow::string output_op_name( + tensorflow::ParseTensorName(output_name).first); TF_Operation* output_op = TF_GraphOperationByName(graph, output_op_name.c_str()); ASSERT_TRUE(output_op != nullptr); diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index 74bc25a491ac01cb725d1c004197e48727c30230..d3311f0cd06f2b151c3567735eb41b5baf72e102 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() { const auto& slice_proto = entry.slices(i); CHECK(filtered_keys .insert(EncodeTensorNameSlice( - std::string(v2_reader_->key()) /* full var's name */, + string(v2_reader_->key()) /* full var's name */, TensorSlice(slice_proto))) .second); } @@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() { new TensorSliceReader::VarToDataTypeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { - if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue; + if (filtered_keys.count(string(v2_reader_->key())) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), v2_reader_->value().size())) << entry.InitializationErrorString(); - string key = std::string(v2_reader_->key()); + string key(v2_reader_->key()); (*var_to_shape_map)[key] = TensorShape(entry.shape()); (*var_to_data_type_map)[key] = DataType(entry.dtype()); } diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h index 
4de1300a7f66a8b4eb8074819432fd7dd597bb15..91654c8d4fb8067ae1fb525ebaa6c54689085545 100644 --- a/tensorflow/c/checkpoint_reader.h +++ b/tensorflow/c/checkpoint_reader.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_C_CHECKPOINT_READER_H -#define TENSORFLOW_C_CHECKPOINT_READER_H +#ifndef TENSORFLOW_C_CHECKPOINT_READER_H_ +#define TENSORFLOW_C_CHECKPOINT_READER_H_ #include #include @@ -79,4 +79,4 @@ class CheckpointReader { } // namespace checkpoint } // namespace tensorflow -#endif // TENSORFLOW_C_CHECKPOINT_READER_H +#endif // TENSORFLOW_C_CHECKPOINT_READER_H_ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc old mode 100644 new mode 100755 index dfb1c9a37644c726e1eabab775593596d5b556b9..77e3878a94eddfa1dfd53844916f453d70bcac4a --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -244,8 +244,8 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, } void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options, - unsigned char async) { - options->async = async; + unsigned char enable) { + options->async = enable; } void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { @@ -253,9 +253,9 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( } TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, - unsigned char async, + unsigned char enable, TF_Status* status) { - status->status = ctx->context.SetAsyncForThread(async); + status->status = ctx->context.SetAsyncForThread(enable); } void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } @@ -273,7 +273,20 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { new tensorflow::IntraProcessRendezvous(device_mgr.get()); return new 
TFE_Context(opts->session_options.options, opts->policy, - opts->async, std::move(device_mgr), r); + opts->async, device_mgr.release(), + /*device_mgr_owned*/ true, r); +} + +TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, + TF_Session* sess, TF_Status* status) { + const tensorflow::DeviceMgr* device_mgr = nullptr; + status->status = sess->session->LocalDeviceManager(&device_mgr); + if (!status->status.ok()) return nullptr; + tensorflow::Rendezvous* r = + new tensorflow::IntraProcessRendezvous(device_mgr); + return new TFE_Context(opts->session_options.options, opts->policy, + opts->async, device_mgr, /*device_mgr_owned*/ false, + r); } void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h old mode 100644 new mode 100755 index a0ebc6fa0a22ed61be91c2974352c2988fb4cd92..eec2750d6eb3bceed8da3ed44812ac2e8fd5c877 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -76,7 +76,7 @@ typedef enum TFE_ContextDevicePlacementPolicy { // Sets the default execution mode (sync/async). Note that this can be // overridden per thread using TFE_ContextSetAsyncForThread. TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, - unsigned char async); + unsigned char enable); TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); @@ -114,7 +114,7 @@ TFE_ContextGetDevicePlacementPolicy(TFE_Context*); // Overrides the execution mode (sync/async) for the current thread. 
TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, - unsigned char async, + unsigned char enable, TF_Status* status); // A tensorflow.ServerDef specifies remote workers (in addition to the current diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index a5c0681e2e4eddae08954d9d0178ca96a3f8f29a..104d52430cf7aa14d4d2a335a1b96e667f21ce87 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -62,15 +62,14 @@ struct TFE_ContextOptions { }; struct TFE_Context { - explicit TFE_Context(const tensorflow::SessionOptions& opts, - TFE_ContextDevicePlacementPolicy default_policy, - bool async, - std::unique_ptr device_mgr, - tensorflow::Rendezvous* rendezvous) + TFE_Context(const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, bool async, + const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned, + tensorflow::Rendezvous* rendezvous) : context(opts, static_cast( default_policy), - async, std::move(device_mgr), rendezvous) {} + async, device_mgr, device_mgr_owned, rendezvous) {} tensorflow::EagerContext context; }; diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 71d5f3613c89762633113b4e1dfb82b8199a1cd1..7126227cf529023eadf38984668a40118641bb1b 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1471,4 +1471,61 @@ void BM_ReadVariable(int iters) { } BENCHMARK(BM_ReadVariable); +TEST(CAPI, StringAttributes) { + // Test that TFE_OpSetAttrString doesn't hold on to the value after it + // returns. 
+ TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::vector dims(4, 1); + TFE_Op* op = TFE_NewOp(ctx, "AvgPool", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* tensor = + TF_AllocateTensor(TF_FLOAT, dims.data(), dims.size(), sizeof(float)); + float tensor_data[] = {1}; + memcpy(TF_TensorData(tensor), tensor_data, TF_TensorByteSize(tensor)); + TFE_TensorHandle* tensor_handle = TFE_NewTensorHandle(tensor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, tensor_handle, status); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(tensor_handle); + + std::vector values(4, 1); + TFE_OpSetAttrIntList(op, "ksize", values.data(), values.size()); + TFE_OpSetAttrIntList(op, "strides", values.data(), values.size()); + + const int BUFFER_SIZE = 10; + char buffer[BUFFER_SIZE]; + std::strncpy(buffer, "VALID", BUFFER_SIZE); + TFE_OpSetAttrString(op, "padding", buffer, std::strlen(buffer)); + // Overwriting value in "buffer", should be fine since TFE_Op + // shouldn't be holding on to it. 
+ std::strncpy(buffer, "NHWC", BUFFER_SIZE); + TFE_OpSetAttrString(op, "data_format", buffer, std::strlen(buffer)); + + TFE_OpSetAttrType(op, "T", TF_FLOAT); + + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(op, &retvals[0], &num_retvals, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + tensor = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(4, TF_TensorByteSize(tensor)); + TF_DeleteTensor(tensor); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(op); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); +} } // namespace diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 1adb0458c35193117b5fa5cfe9ceffbaaf699af7..ce038a4b57b2699c6d09fcf75ef41cecec4e97b8 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -440,6 +440,15 @@ Status InitialGradients(const VSpace& vspace, return Status::OK(); } +gtl::FlatMap>* FunctionsAcceptingNoneForIndicesMap() { + static auto* const m = new gtl::FlatMap>({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"SparseSoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); + return m; +} + } // namespace // If over kMinAggregateCount gradients are accumulated and the total @@ -485,10 +494,6 @@ Status GradientTape::ComputeGradient( VLOG(1) << " " << t; } } - gtl::FlatMap> functions_accept_none_for_indices({ - {"SoftmaxCrossEntropyWithLogits", {1}}, - {"FusedBatchNorm", {1, 2, 3, 4}}, - }); while (!op_stack.empty()) { const int64 op = op_stack.back(); VLOG(1) << "Popped " << op; @@ -509,8 +514,8 @@ Status GradientTape::ComputeGradient( auto grad_it = gradients.find(id); if (grad_it == gradients.end()) { auto func_name_it = - functions_accept_none_for_indices.find(trace.op_type); - if (func_name_it != functions_accept_none_for_indices.end() && + 
FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type); + if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() && func_name_it->second.find(i) != func_name_it->second.end()) { out_gradients.push_back(nullptr); } else { diff --git a/tensorflow/c/tf_status_helper.h b/tensorflow/c/tf_status_helper.h index 86e687df205617018d94c19ac34fdc3bf54dcc6f..7661a01de4afcefbb66b33a05534e22d2ba1baa0 100644 --- a/tensorflow/c/tf_status_helper.h +++ b/tensorflow/c/tf_status_helper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H -#define TENSORFLOW_C_TF_STATUS_HELPER_H +#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H_ +#define TENSORFLOW_C_TF_STATUS_HELPER_H_ #include "tensorflow/c/c_api.h" #include "tensorflow/core/lib/core/status.h" @@ -29,4 +29,4 @@ Status StatusFromTF_Status(const TF_Status* tf_status); } // namespace tensorflow -#endif // TENSORFLOW_C_TF_STATUS_HELPER_H +#endif // TENSORFLOW_C_TF_STATUS_HELPER_H_ diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index dfdef88945deca376368edd6f7aa322b1e1cbf94..a32d1b1eb50fc715084f5ee663a732770db1883c 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -466,7 +466,7 @@ string AvoidCPPKeywords(StringPiece name) { if (IsCPPKeyword(name)) { return strings::StrCat(name, "_"); } - return std::string(name); + return string(name); } void InferArgAttributes(const OpDef::ArgDef& arg, @@ -508,15 +508,6 @@ bool HasOptionalAttrs( return false; } -const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { - for (int i = 0; i < api_def.in_arg_size(); ++i) { - if (api_def.in_arg(i).name() == name) { - return &api_def.in_arg(i); - } - } - return nullptr; -} - struct OpInfo { // graph_op_def: The OpDef used by the runtime, has the names that // must be 
used when calling NodeBuilder. diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 8c886f31711eb014fb9e9d600c9c78cf22073f71..7f6ac4cae78d8d6e118837fce9ae5270336cdc89 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -225,7 +225,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints( for (const string& entry : node_constraints) { StringPiece s(entry); if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) { - current_constraints.insert(std::string(s)); + current_constraints.emplace(s); } } } else { diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 5dcf00857df0eabd4e99f2782c1910515a9be265..1329b568ab8d4cc5cc5eed554e74bf1100d9bdcf 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -441,21 +441,20 @@ Status RealDivGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("RealDiv", RealDivGrad); -Status UnsafeDivGrad(const Scope& scope, const Operation& op, - const std::vector& grad_inputs, - std::vector* grad_outputs) { +Status DivNoNanGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { auto x_1 = ConjugateHelper(scope, op.input(0)); auto x_2 = ConjugateHelper(scope, op.input(1)); // y = x_1 / x_2 // dy/dx_1 = 1/x_2 // dy/dx_2 = -x_1/x_2^2 - auto gx_1 = UnsafeDiv(scope, grad_inputs[0], x_2); - auto gx_2 = - Mul(scope, grad_inputs[0], - UnsafeDiv(scope, UnsafeDiv(scope, Neg(scope, x_1), x_2), x_2)); + auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2); + auto gx_2 = Mul(scope, grad_inputs[0], + DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2)); return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); } -REGISTER_GRADIENT_OP("UnsafeDiv", UnsafeDivGrad); +REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad); Status SquaredDifferenceGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, diff --git 
a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 88aef1fab410e11aa17a9e44578f5db95ed6e52b..c16938322c3555939ace1013f3bb95c5689b503e 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -33,6 +33,7 @@ using ops::AddN; using ops::BatchMatMul; using ops::Const; using ops::Div; +using ops::DivNoNan; using ops::MatMul; using ops::Max; using ops::Maximum; @@ -48,7 +49,6 @@ using ops::SegmentSum; using ops::SquaredDifference; using ops::Sub; using ops::Sum; -using ops::UnsafeDiv; // TODO(andydavis) Test gradient function against numeric gradients output. // TODO(andydavis) As more gradients are added move common test functions @@ -854,13 +854,13 @@ TEST_F(NaryGradTest, RealDiv) { RunTest({x}, {x_shape}, {y}, {x_shape}); } -TEST_F(NaryGradTest, UnsafeDiv) { +TEST_F(NaryGradTest, DivNoNan) { { TensorShape x_shape({3, 2, 5}); const auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); // Test x / (1 + |x|) rather than x_1 / x_2 to avoid triggering large // division errors in the numeric estimator used by the gradient checker. - const auto y = UnsafeDiv( + const auto y = DivNoNan( scope_, x, Add(scope_, Const(scope_, 1), Abs(scope_, x))); RunTest({x}, {x_shape}, {y}, {x_shape}); } @@ -868,7 +868,7 @@ TEST_F(NaryGradTest, UnsafeDiv) { // Return 0 gradient (rather than NaN) for division by zero. 
const auto x = Placeholder(scope_, DT_FLOAT); const auto zero = Const(scope_, 0.0); - const auto y = UnsafeDiv(scope_, x, zero); + const auto y = DivNoNan(scope_, x, zero); std::vector grad_outputs; TF_EXPECT_OK(AddSymbolicGradients(scope_, {y}, {x}, &grad_outputs)); diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 3830416159158cca8bfb8422c2959b49fa42406d..c6abe2f41b9b5ec2faee6f65b429ff606f8ac08e 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -148,7 +148,7 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; const StringPiece main_op_name = main_op_it->second.node_list().value(0); - return RunOnce(run_options, inputs, {}, {main_op_name.ToString()}, + return RunOnce(run_options, inputs, {}, {string(main_op_name)}, nullptr /* outputs */, &run_metadata, session); } return Status::OK(); @@ -182,12 +182,12 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, variables_path_tensor.scalar()() = variables_path; std::vector> inputs = { - {variable_filename_const_op_name.ToString(), variables_path_tensor}}; + {string(variable_filename_const_op_name), variables_path_tensor}}; AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; - return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()}, + return RunOnce(run_options, inputs, {}, {string(restore_op_name)}, nullptr /* outputs */, &run_metadata, session); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 1899a32e4dc5487875f091fece6acf0c44c9243f..6c29f09cde7ee17c11cb44ce48d8e9128daae4d0 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -32,7 +32,6 @@ cc_library( deps = [ ":embedded_protocol_buffers", "//tensorflow/compiler/tf2xla", - "//tensorflow/compiler/tf2xla:common", 
"//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", @@ -55,6 +54,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -71,6 +73,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", "@llvm//:support", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep ], @@ -99,6 +102,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -188,11 +192,13 @@ cc_library( srcs = ["embedded_protocol_buffers.cc"], hdrs = ["embedded_protocol_buffers.h"], deps = [ - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", "@llvm//:target", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 89fefdad54fabcc953e72c6aa7a2361468b61259..2b1ce34b3770a47e31d4f623b1b4f4650206737e 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -19,17 +19,18 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/types/span.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" -#include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -141,7 +142,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, } rewrites->push_back({"{{I}}", strings::StrCat(i)}); rewrites->push_back({"{{TYPE}}", type}); - rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")}); + rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); rewrites->push_back({"{{INDICES}}", indices}); return Status::OK(); @@ -157,8 +158,9 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, // text-templating mechanism. string RewriteWithName(const string& name, string code, const std::vector>& rewrites) { - str_util::ReplaceAllPairs(&code, rewrites); - return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true); + absl::StrReplaceAll(rewrites, &code); + absl::StrReplaceAll({{"{{NAME}}", name}}, &code); + return code; } // Generate methods for args (inputs). 
@@ -570,11 +572,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())}, - {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")}, + {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, {"{{DECLS_FROM_OBJ_FILE}}", - str_util::Join(metadata_result.header_variable_decls, "\n")}, + absl::StrJoin(metadata_result.header_variable_decls, "\n")}, {"{{ENTRY}}", compile_result.entry_point}, {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}", metadata_result.hlo_profile_printer_data_access_shim}, @@ -594,8 +596,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())}, {"{{BUFFER_INFOS_AS_STRING}}", - str_util::Join(buffer_infos_as_strings, ",\n")}}; - str_util::ReplaceAllPairs(header, rewrites); + absl::StrJoin(buffer_infos_as_strings, ",\n")}}; + absl::StrReplaceAll(rewrites, header); return Status::OK(); } @@ -617,7 +619,8 @@ Status GenerateMetadata(const CodegenOpts& opts, if (opts.gen_program_shape) { program_shape = - tensorflow::MakeUnique(compile_result.program_shape); + absl::make_unique(compile_result.program_shape); + // The parameter names are currently meaningless, and redundant with the // rest of our metadata, so clear them out to avoid confusion and save // space. diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 60d59ae996e8f7ec490c98aeab05182626e61976..e3a53edb7368c209bea16a9e34b1f452a8ff4bf8 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -18,13 +18,13 @@ limitations under the License. 
#include #include +#include "absl/strings/match.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -34,9 +34,9 @@ namespace { using ::tensorflow::cpu_function_runtime::BufferInfo; -void ExpectErrorContains(const Status& status, StringPiece str) { +void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) + EXPECT_TRUE(absl::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 4e27aafec7747655d8e4ea3ddd1788d495ca0710..f1e8e5c08482e15d989c19a43aa7c5f437cd091d 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_replace.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/LLVMContext.h" @@ -26,8 +28,6 @@ limitations under the License. 
#include "llvm/Support/TargetRegistry.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" -#include "tensorflow/compiler/tf2xla/str_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/util.h" @@ -65,14 +65,13 @@ static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, " return proto;\n" " }()"; - str_util::ReplaceAllPairs( - &code, + return absl::StrReplaceAll( + code, { {"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)}, {"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)}, {"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)}, }); - return code; } static StatusOr CodegenModule(llvm::TargetMachine* target_machine, @@ -97,7 +96,7 @@ static StatusOr> GetTargetMachineFromTriple(StringPiece target_triple) { std::string error; std::string normalized_triple = - llvm::Triple::normalize(AsStringRef(target_triple)); + llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple))); const llvm::Target* target = llvm::TargetRegistry::lookupTarget(normalized_triple, error); if (target == nullptr) { @@ -105,20 +104,20 @@ GetTargetMachineFromTriple(StringPiece target_triple) { error.c_str()); } - return WrapUnique(target->createTargetMachine( + return absl::WrapUnique(target->createTargetMachine( normalized_triple, /*CPU=*/"", /*Features=*/"", llvm::TargetOptions(), llvm::None)); } StatusOr CreateEmbeddedProtocolBuffers( StringPiece target_triple, - gtl::ArraySlice protobufs_to_embed) { + absl::Span protobufs_to_embed) { TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, GetTargetMachineFromTriple(target_triple)); llvm::LLVMContext llvm_context; std::unique_ptr module_with_serialized_proto = - MakeUnique("embedded_data_module", llvm_context); + absl::make_unique("embedded_data_module", llvm_context); EmbeddedProtocolBuffers result; diff --git 
a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index 4e194a6aba9a9efcad27c47c42e148d8e537ae68..4f940c019750f49da4ad2386aa4b23281cc5a9fc 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -20,8 +20,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ #define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -84,7 +84,7 @@ struct ProtobufToEmbed { // EmbeddedProtocolBuffers instance. StatusOr CreateEmbeddedProtocolBuffers( StringPiece target_triple, - gtl::ArraySlice protobufs_to_embed); + absl::Span protobufs_to_embed); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 0ecc3feeb6fef1dd691ab2785b3221075a79ba88..723e9bec8afcfbf7ceeeb59c63e4e12442fdb7ab 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -187,6 +187,9 @@ tf_library( cpp_class = "MatMulAndAddCompWithProfiling", enable_xla_hlo_profiling = True, graph = "test_graph_tfmatmulandadd.pb", + tags = [ + "manual", + ], ) tf_library( @@ -226,5 +229,6 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 0c0c676ece78565e03578d3e33633c7e23b77669..dd2b151098f2054571ac32b8b506cbc00659588a 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#define EIGEN_USE_THREADS #define EIGEN_USE_CUSTOM_THREAD_POOL +#include "absl/strings/str_split.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -546,7 +546,7 @@ TEST(TFCompileTest, HloProfiling) { VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string; std::vector hlo_profile_lines = - tensorflow::str_util::Split(hlo_profile_as_string, '\n'); + absl::StrSplit(hlo_profile_as_string, '\n'); auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 839e1588b7be6c91cf30c87bbaf75402446bd169..f3c44e9dda8ce96a268420a7f4d0f22e50ddfe41 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" @@ -34,7 +36,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -55,7 +56,7 @@ const char kUsageHeader[] = "\n"; Status ReadProtoFile(const string& fname, protobuf::Message* proto) { - if (str_util::EndsWith(fname, ".pbtxt")) { + if (absl::EndsWith(fname, ".pbtxt")) { return ReadTextProto(Env::Default(), fname, proto); } else { return ReadBinaryProto(Env::Default(), fname, proto); @@ -75,7 +76,7 @@ Status Main(const MainFlags& flags) { for (const tf2xla::Fetch& fetch : config.fetch()) { nodes.insert(fetch.id().node_name()); } - std::cout << str_util::Join(nodes, ","); + std::cout << absl::StrJoin(nodes, ","); return Status::OK(); } diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index e059f77563ba46d9df247c897ece7a1c14f7801a..df81f3c23e38a2ec2cea827cd0adb123855e7714 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -128,11 +128,11 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -191,6 +191,7 @@ cc_library( "//tensorflow/core/kernels/data:generator_dataset_op", "//tensorflow/core/kernels/data:iterator_ops", "//tensorflow/core/kernels/data:prefetch_dataset_op", + "@com_google_absl//absl/memory", ], ) @@ -235,6 +236,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/memory", ], ) @@ -283,6 +285,7 @@ cc_library( 
"//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) @@ -303,6 +306,52 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "resource_operation_safety_analysis", + srcs = ["resource_operation_safety_analysis.cc"], + hdrs = ["resource_operation_safety_analysis.h"], + deps = [ + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "resource_operation_safety_analysis_test", + srcs = ["resource_operation_safety_analysis_test.cc"], + deps = [ + ":common", + ":resource_operation_safety_analysis", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/strings", ], ) @@ -331,11 +380,10 @@ cc_library( ":union_find", ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", - "//tensorflow/compiler/jit/kernels:parallel_check_op", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", - 
"//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -347,6 +395,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/strings", ], ) @@ -355,12 +404,13 @@ cc_library( srcs = ["xla_cluster_util.cc"], hdrs = ["xla_cluster_util.h"], deps = [ + ":resource_operation_safety_analysis", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/types:optional", ], ) @@ -433,6 +483,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/jit/kernels:xla_launch_op", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -444,6 +495,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/strings", ], ) @@ -524,6 +576,9 @@ tf_cuda_cc_test( ":common", ":xla_cluster_util", ":xla_fusion_optimizer", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/core:graph", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index a2e6285339f9ed0bde8d72f5b4752b1ecc22f426..56b034a30b7bddb023e54ead22c91a7a18095d2d 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/create_xla_launch_op.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" @@ -125,7 +126,8 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, const DataTypeVector& arg_types = (*fbody)->arg_types; std::vector const_args(arg_types.size()); // If we can't analyze the const args. Bail out. - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *((*fbody)->graph), &const_args, /*compile_time_const_nodes=*/nullptr)); for (int i = 0; i < const_args.size(); ++i) { if (const_args[i]) { @@ -207,8 +209,13 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, // device memory. // XlaLaunch kernel keeps all outputs (including constants, which it copies), - // in device memory + // in device memory except for resources. MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + for (int i = 0; i < fbody->ret_types.size(); ++i) { + if (fbody->ret_types[i] == DT_RESOURCE) { + output_memory_types[i] = HOST_MEMORY; + } + } // Create the kernel. 
NameAttrList function; @@ -223,8 +230,8 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); - *kernel = MakeUnique(&construction, constant_arg_indices, - resource_arg_indices, function); + *kernel = absl::make_unique( + &construction, constant_arg_indices, resource_arg_indices, function); return s; } diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc index b75ab486b80e098bc0a59f9ea8cdbaa23a28fef9..73866607621cd745f6e640a14405daebf0dd9985 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/create_xla_launch_op.h" +#include "absl/memory/memory.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function_testlib.h" @@ -65,11 +66,11 @@ class CreateXlaLaunchOpTest : public ::testing::Test { for (const auto& fdef : flib) { *(proto.add_function()) = fdef; } - lib_def_ = - MakeUnique(OpRegistry::Global(), proto); + lib_def_ = absl::make_unique( + OpRegistry::Global(), proto); OptimizerOptions opts; - device_mgr_ = MakeUnique(devices_); - pflr_ = MakeUnique( + device_mgr_ = absl::make_unique(devices_); + pflr_ = absl::make_unique( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 62007e6115d3fb81def844fcfa462094e223f565..82aa03810bc0ecee8ae92ed6f286867eea893287 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ 
b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" @@ -21,18 +22,79 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" // ALGORITHM OVERVIEW +// ================== // // We map every output produced by each node in the TensorFlow graph (including // control dependence) into an instance of the Predicate class. Instances of // Predicate denote logical formulas and mapping a node `n` to a predicate -// `pred` implies that `n` is executed whenver `pred` is true. Then we can -// deduce mismatching liveness in the inputs to node by comparing the predicate -// those inputs are mapped to. +// `pred` implies that `n` is live whenever `pred` is true. Then we can deduce +// mismatching liveness in the inputs to node by comparing the predicate those +// inputs are mapped to. The core logic of this pass resides in creating the +// map from TensorFlow nodes to predicates. // -// Loops are handled pessimistically -- we map Merge nodes with backedges to -// uninterpreted symbols (the same kind we use to represent Switch and _Recv). -// Predicate equality has to hold over all possible assignments to these -// uninterpreted symbols. +// +// MAPPING NODES TO PREDICATES, MODULO CYCLES +// ------------------------------------------ +// +// If we ignore cycles for a moment, computing predicates is fairly +// straightforward. We traverse the graph in RPO, mapping each node to a +// predicate based on the predicates its inputs are mapped to. For instance a +// Merge(X, Y) node will be mapped to OR(PredicateFor(X), PredicateFor(Y)). 
+// Roughly speaking, we abstract interpret each node on the "liveness" domain, +// where values in the domain represent if a tensor carries a dead signal or +// not. +// +// +// DEALING WITH CYCLES +// ------------------- +// +// We map Merge nodes that are the target of a backedge to AndRecurrence +// instances. An AndRecurrence with start() = S and step() = X, printed as +// {S,&,X}, *roughly* represents the infinite list of predicates +// [S,S&X,S&X&X,S&X&X&X, ...]. So {S,&,X} can be used to represent the predicate +// for Merge in a graph like: +// +// Init +// | +// v +// Merge <-----------+ +// | | +// v | +// Incr | +// | | +// v | +// Switch <- Cond | +// | | +// v (oidx: 1) | +// | | +// +---------------+ +// +// Where S is the predicate for Init and X is the predicate that asserts that +// Cond is true. {S,&,X} states that Merge is live on the first "iteration" iff +// S is true, live on the second iteration iff "S&X" is true, live on the third +// iteration iff "S&X&X" is true etc. There is a subtlety here, S&X&X would +// normally be equivalent to S&X which isn't quite what we want to represent. +// Instead we want {S,&,X} to denote the infinite list [S, S&X, +// S&X&X',S&X&X'&X'', ...] where X, X', X'' are predicates that assert Cond is +// true on iteration 0, 1, 2 respectively. This is made more precise in the +// comment on the AndRecurrence class. +// +// The general algorithm that deals with cycles does two RPO (reverse post +// order) passes over the graph. On the first pass it assigns a symbolic +// predicate to merge nodes with backedges. On the second pass it tries to +// pattern match the predicates for the backedges of these merges and infer an +// AndRecurrence for the merge. +// +// In other words, we do a pessimistic data flow analysis where the data-flow +// lattice has two elements, Symbolic and NonSymbolic with Symbolic > +// NonSymbolic. The lattice has height = 2 so two iterations are sufficient to +// converge.
We don't do an optimistic data flow analysis to make pattern +// matching easier: if we assigned the predicate of the initial value to the +// merge during the first pass, on the second pass the backedge may see a +// simplified value that would be difficult to pattern match. +// +// We still use symbolic predicates for merges for which we can't pattern match +// on the backedge predicate. This is conservatively correct. namespace tensorflow { @@ -42,15 +104,21 @@ namespace { // above. class Predicate { public: - enum class Kind { kAnd, kOr, kNot, kSymbol }; + enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol }; virtual string ToString() const = 0; int64 hash() const { return hash_; } - virtual gtl::ArraySlice GetOperands() const = 0; + virtual absl::Span GetOperands() const = 0; virtual Kind kind() const = 0; virtual ~Predicate() {} + // Invokes func on p and on all of its operands recursively. Does not invoke + // `func` on the same Predicate instance twice. Aborts the search if `func` + // returns true. 
+ template + static void Visit(Predicate* p, const FunctionTy& func); + protected: explicit Predicate(int64 hash) : hash_(hash) {} @@ -61,7 +129,7 @@ class Predicate { }; int64 HashPredicateSequence(Predicate::Kind kind, - gtl::ArraySlice preds) { + absl::Span preds) { int64 hash = ::tensorflow::hash()(kind); for (Predicate* pred : preds) { hash = Hash64Combine(hash, pred->hash()); @@ -86,13 +154,15 @@ class AndPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", str_util::Join(operands_str, " & "), ")"); + return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); } Kind kind() const override { return Kind::kAnd; } - gtl::ArraySlice GetOperands() const override { return operands_; } - gtl::ArraySlice operands() const { return operands_; } + absl::Span GetOperands() const override { + return operands_; + } + absl::Span operands() const { return operands_; } private: std::vector operands_; @@ -115,12 +185,14 @@ class OrPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", str_util::Join(operands_str, " | "), ")"); + return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); } Kind kind() const override { return Kind::kOr; } - gtl::ArraySlice GetOperands() const override { return operands_; } - gtl::ArraySlice operands() const { return operands_; } + absl::Span GetOperands() const override { + return operands_; + } + absl::Span operands() const { return operands_; } private: std::vector operands_; @@ -139,16 +211,54 @@ class NotPredicate : public Predicate { Kind kind() const override { return Kind::kNot; } Predicate* operand() const { return operands_[0]; } - gtl::ArraySlice GetOperands() const override { return operands_; } + absl::Span GetOperands() const override { + return operands_; + } private: std::array operands_; }; +// Represents an infinite list of 
predicates. +// +// An AndRecurrence with start = S and step = X is printed as {S,&,X} and stands +// for the list of predicates: +// +// S, S & GenSym(X,1), S & GenSym(X,1) & GenSym(X,2), ... +// +// where GenSym(<expression>, <id>) renames every SymbolPredicate in +// <expression> by appending <id> to it, in effect creating a "fresh" symbol. +// This means {P,&,Q} is not equal to "P on the first iteration; P&Q on +// subsequent iterations". +class AndRecurrencePredicate : public Predicate { + public: + explicit AndRecurrencePredicate(Predicate* start, Predicate* step) + : Predicate(HashPredicateSequence(Kind::kAndRecurrence, {start, step})), + operands_({start, step}) {} + + Predicate* start() const { return operands_[0]; } + Predicate* step() const { return operands_[1]; } + + string ToString() const override { + return strings::StrCat("{", start()->ToString(), ",&,", step()->ToString(), + "}"); + } + + Kind kind() const override { return Kind::kAndRecurrence; } + + absl::Span GetOperands() const override { + return operands_; + } + + private: + std::array operands_; +}; + // Represents an uninterpreted symbol in a logical predicate. // // Two predicates are equivalent iff they are equivalent for all assignments to -// the symbols contained in them. +// the symbols contained in them, i.e. predicates are forall quantified over +// symbols. class SymbolPredicate : public Predicate { public: explicit SymbolPredicate(TensorId tensor_id, bool must_be_true) @@ -162,7 +272,7 @@ class SymbolPredicate : public Predicate { } Kind kind() const override { return Kind::kSymbol; } - gtl::ArraySlice GetOperands() const override { return {}; } + absl::Span GetOperands() const override { return {}; } // If `must_be_true()` is true this SymbolPredicate represents the proposition // "tensor_id() is live and evaluates to true".
@@ -184,15 +294,38 @@ class SymbolPredicate : public Predicate { } }; +template +/*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) { + gtl::FlatSet visited; + std::vector stack; + + stack.push_back(p); + visited.insert(p); + + while (!stack.empty()) { + Predicate* current = stack.back(); + stack.pop_back(); + bool done = func(current); + if (done) { + return; + } + for (Predicate* op : current->GetOperands()) { + if (visited.insert(op).second) { + stack.push_back(op); + } + } + } +} + // Creates and owns Predicate instances. Simplifies predicates as it creates // them. class PredicateFactory { public: - Predicate* MakeAndPredicate(gtl::ArraySlice operands) { + Predicate* MakeAndPredicate(absl::Span operands) { return MakeAndOrImpl(operands, /*is_and=*/true); } - Predicate* MakeOrPredicate(gtl::ArraySlice operands) { + Predicate* MakeOrPredicate(absl::Span operands) { return MakeAndOrImpl(operands, /*is_and=*/false); } @@ -209,6 +342,21 @@ class PredicateFactory { } } + Predicate* MakeAndRecurrencePredicate(Predicate* start, Predicate* step) { + auto it = interned_and_rec_instances_.find({start, step}); + if (it != interned_and_rec_instances_.end()) { + return it->second.get(); + } + + std::unique_ptr new_pred = + Make(start, step); + Predicate* new_pred_ptr = new_pred.get(); + CHECK(interned_and_rec_instances_ + .emplace(SignatureForAndRec(start, step), std::move(new_pred)) + .second); + return new_pred_ptr; + } + Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) { SignatureForSymbol signature = {tensor_id, must_be_true}; auto it = interned_symbol_instances_.find(signature); @@ -234,7 +382,7 @@ class PredicateFactory { new PredicateT(std::forward(args)...)); } - Predicate* MakeAndOrImpl(gtl::ArraySlice operands, bool is_and); + Predicate* MakeAndOrImpl(absl::Span operands, bool is_and); // Predicate instances are interned, meaning that there is only a single // instance of a Predicate object with a given content. 
This makes checking @@ -247,8 +395,9 @@ class PredicateFactory { // for the owning pointers to predicate instances. using SignatureForAndOr = - std::pair>; + std::pair>; using SignatureForNot = Predicate*; + using SignatureForAndRec = std::pair; using SignatureForSymbol = std::pair; struct HashSignatureForAndOr { @@ -273,14 +422,16 @@ class PredicateFactory { interned_and_or_instances_; gtl::FlatMap> interned_not_instances_; + gtl::FlatMap> + interned_and_rec_instances_; gtl::FlatMap, HashSignatureForSymbol> interned_symbol_instances_; }; // Common code to create AndPredicate or OrPredicate instances. -Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, - bool is_and) { +Predicate* PredicateFactory::MakeAndOrImpl( + absl::Span operands, bool is_and) { Predicate::Kind pred_kind = is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; gtl::FlatSet simplified_ops_set; @@ -331,7 +482,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice operands, // NB! Because we'll use a non-owning reference to simplified_ops in the // key for interned_and_or_instances_ we need to be careful to std::move() // it all the way through. - gtl::ArraySlice operands_slice = simplified_ops; + absl::Span operands_slice = simplified_ops; std::unique_ptr new_pred = is_and ? 
Make(std::move(simplified_ops)) : Make(std::move(simplified_ops)); @@ -353,6 +504,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} Status Populate(); + Status PopulateWithReversePostOrder(absl::Span rpo); bool HasInputsWithMismatchingDeadness(const Node& node) override; void Print() const override; gtl::FlatMap PredicateMapAsString() const; @@ -361,20 +513,40 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; std::vector GetIncomingPreds(Node* n, EdgeKind edge_kind); - void SetPred(Node* n, int output_idx, Predicate* pred) { - CHECK( - predicate_map_.insert({TensorId(n->name(), output_idx), pred}).second); + + // Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th + // bit of `should_revisit` if `pred` is different from the current predicate + // for the `output_idx` output of `n`. + void SetPredicate(Node* n, int output_idx, Predicate* pred, + std::vector* should_revisit) { + auto insert_result = + predicate_map_.insert({TensorId(n->name(), output_idx), pred}); + if (!insert_result.second && insert_result.first->second != pred) { + VLOG(4) << "For " << n->name() << ":" << output_idx << " from " + << insert_result.first->second->ToString() << " " + << insert_result.first->second << " to " << pred->ToString() + << " " << pred; + insert_result.first->second = pred; + if (should_revisit != nullptr) { + for (const Edge* e : n->out_edges()) { + (*should_revisit)[e->dst()->id()] = true; + } + } + } } - void SetPred(Node* n, gtl::ArraySlice output_idxs, Predicate* pred) { + + void SetPredicate(Node* n, absl::Span output_idxs, Predicate* pred, + std::vector* should_revisit) { for (int output_idx : output_idxs) { - SetPred(n, output_idx, pred); + SetPredicate(n, output_idx, pred, should_revisit); } } - Status HandleSwitch(Node* n); - Status HandleMerge(Node* n); - Status HandleRecv(Node* n); - Status HandleGeneric(Node* n); + 
Status HandleSwitch(Node* n, std::vector* should_revisit); + Status HandleMerge(Node* n, std::vector* should_revisit); + Status HandleRecv(Node* n, std::vector* should_revisit); + Status HandleGeneric(Node* n, std::vector* should_revisit); + Status HandleNode(Node* n, std::vector* should_revisit); const Graph& graph_; gtl::FlatMap predicate_map_; @@ -397,14 +569,15 @@ std::vector DeadnessAnalysisImpl::GetIncomingPreds( if (should_process) { auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); - CHECK(it != predicate_map_.end()); + CHECK(it != predicate_map_.end()) << n->name(); incoming_preds.push_back(it->second); } } return incoming_preds; } -Status DeadnessAnalysisImpl::HandleSwitch(Node* n) { +Status DeadnessAnalysisImpl::HandleSwitch(Node* n, + std::vector* should_revisit) { std::vector input_preds = GetIncomingPreds(n, EdgeKind::kDataAndControl); const Edge* pred_edge; @@ -416,84 +589,252 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n) { // Output 0 is alive iff all inputs are alive and the condition is false. input_preds.push_back(false_switch); - SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds)); + SetPredicate(n, 0, predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); input_preds.pop_back(); // Output 1 is alive iff all inputs are alive and the condition is true. input_preds.push_back(true_switch); - SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds)); + SetPredicate(n, 1, predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); input_preds.pop_back(); - // Control is alive iff any inputs are alive. - SetPred(n, Graph::kControlSlot, - predicate_factory_.MakeAndPredicate(input_preds)); + // Control is alive iff all inputs are alive. 
+ SetPredicate(n, Graph::kControlSlot, + predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); return Status::OK(); } -Status DeadnessAnalysisImpl::HandleMerge(Node* n) { +namespace { +const Edge* FindUniqueBackedge(Node* merge) { + CHECK(merge->IsMerge()); + const Edge* result = nullptr; + for (const Edge* e : merge->in_edges()) { + if (e->src()->IsNextIteration()) { + CHECK_EQ(result, nullptr) + << "Multiple backedges to " << merge->DebugString(); + result = e; + } + } + return result; +} + +// If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step +// does not contain `symbolic_predicate` as an inner (not top-level) operand +// then returns `Step`. Otherwise returns nullptr. +Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory, + Predicate* symbolic_predicate, + Predicate* backedge_predicate) { + CHECK(dynamic_cast(symbolic_predicate)); + if (backedge_predicate->kind() != Predicate::Kind::kAnd) { + return nullptr; + } + + std::vector and_ops; + absl::Span recurrent_pred_ops = + backedge_predicate->GetOperands(); + + bool found_sym = false; + for (Predicate* and_op : recurrent_pred_ops) { + // We want the `symbol_predicate` to be the one of the operands of + // `backedge_predicate`, + if (and_op == symbolic_predicate) { + found_sym = true; + continue; + } + + // but we don't want it to be present anywhere else in the formula. E.g. we + // don't want the recurrent predicate to be + // symbol_predicate&(X|symbol_predicate). + bool found_sym_as_inner_operand = false; + auto has_self_as_inner_operand = [&](Predicate* p) { + if (p == symbolic_predicate) { + found_sym_as_inner_operand = true; + return true; // Stop searching, we're done. + } + + // Continue searching. + return false; + }; + + Predicate::Visit(and_op, has_self_as_inner_operand); + if (found_sym_as_inner_operand) { + return nullptr; + } + and_ops.push_back(and_op); + } + + return found_sym ? 
predicate_factory->MakeAndPredicate(and_ops) : nullptr; +} +} // namespace + +Status DeadnessAnalysisImpl::HandleMerge(Node* n, + std::vector* should_revisit) { // Merge ignores deadness of its control inputs. A merge that isn't the - // target of a backedge has is alive iff any of its data inputs are. We treat - // the liveness of a merge that is the target of a backedge symbolically. + // target of a backedge has is alive iff any of its data inputs are. The + // liveness of a merge that is the target of a backedge can sometimes be + // represented using a AndRecurrencePredicate. If neither apply, we represent + // the liveness of the merge symbolically. + + bool has_unvisited_backedge = false; + for (const Edge* e : n->in_edges()) { + if (!e->IsControlEdge() && e->src()->IsNextIteration()) { + has_unvisited_backedge |= !predicate_map_.count(InputEdgeToTensorId(e)); + } + } + + auto it = predicate_map_.find(TensorId(n->name(), 0)); + if (it == predicate_map_.end()) { + if (has_unvisited_backedge) { + // We're visiting this merge for the first time and it has an unvisited + // backedge. + Predicate* input_data_pred = predicate_factory_.MakeSymbolPredicate( + TensorId(n->name(), 0), /*must_be_true=*/false); + SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, + should_revisit); + return Status::OK(); + } - bool has_backedge = std::any_of( - n->in_edges().begin(), n->in_edges().end(), [](const Edge* e) { - return !e->IsControlEdge() && e->src()->IsNextIteration(); - }); + // We're visiting this merge for the first time and it is a acyclic merge. + Predicate* input_data_pred = predicate_factory_.MakeOrPredicate( + GetIncomingPreds(n, EdgeKind::kDataOnly)); + SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, + should_revisit); + return Status::OK(); + } - Predicate* input_data_pred = - has_backedge ? 
predicate_factory_.MakeSymbolPredicate( - TensorId(n->name(), 0), /*must_be_true=*/false) - : predicate_factory_.MakeOrPredicate( - GetIncomingPreds(n, EdgeKind::kDataOnly)); + if (it->second->kind() == Predicate::Kind::kSymbol) { + // Last time we visited this merge we only got a symbolic predicate because + // of an unvisited backedge. Try to pattern match the predicate expression + // for that backedge (which should be visited now) into an and recurrence + // for the merge node. + if (const Edge* unique_backedge = FindUniqueBackedge(n)) { + if (Predicate* step = DeduceStepPredicate( + &predicate_factory_, it->second, + predicate_map_[InputEdgeToTensorId(unique_backedge)])) { + // If the predicate for the backedge is "Sym&X" where "Sym" is the + // predicate for the merge then the merge has predicate {S,&,X} where S + // is the predicate for the merge ignoring the backedge. + std::vector non_recurrent_inputs; + for (const Edge* e : n->in_edges()) { + if (e != unique_backedge) { + non_recurrent_inputs.push_back( + predicate_map_[InputEdgeToTensorId(e)]); + } + } - SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred); + Predicate* start = + predicate_factory_.MakeOrPredicate(non_recurrent_inputs); + Predicate* and_rec = + predicate_factory_.MakeAndRecurrencePredicate(start, step); + SetPredicate(n, {0, 1, Graph::kControlSlot}, and_rec, should_revisit); + return Status::OK(); + } + } + } return Status::OK(); } -Status DeadnessAnalysisImpl::HandleRecv(Node* n) { +Status DeadnessAnalysisImpl::HandleRecv(Node* n, + std::vector* should_revisit) { // In addition to being alive or dead based on the inputs, a _Recv can also // acquire a dead signal from a _Send. 
std::vector input_preds = GetIncomingPreds(n, EdgeKind::kDataAndControl); input_preds.push_back(predicate_factory_.MakeSymbolPredicate( TensorId(n->name(), 0), /*must_be_true=*/false)); - SetPred(n, {0, Graph::kControlSlot}, - predicate_factory_.MakeAndPredicate(input_preds)); + SetPredicate(n, {0, Graph::kControlSlot}, + predicate_factory_.MakeAndPredicate(input_preds), + should_revisit); return Status::OK(); } -Status DeadnessAnalysisImpl::HandleGeneric(Node* n) { +Status DeadnessAnalysisImpl::HandleGeneric(Node* n, + std::vector* should_revisit) { // Generally nodes are alive iff all their inputs are alive. Predicate* pred = predicate_factory_.MakeAndPredicate( GetIncomingPreds(n, EdgeKind::kDataAndControl)); for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) { - SetPred(n, output_idx, pred); + SetPredicate(n, output_idx, pred, should_revisit); + } + SetPredicate(n, Graph::kControlSlot, pred, should_revisit); + return Status::OK(); +} + +Status DeadnessAnalysisImpl::HandleNode(Node* n, + std::vector* should_revisit) { + if (n->IsSwitch()) { + TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit)); + } else if (n->IsMerge()) { + TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit)); + } else if (n->IsControlTrigger()) { + SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(), + nullptr); + } else if (n->IsRecv() || n->IsHostRecv()) { + TF_RETURN_IF_ERROR(HandleRecv(n, should_revisit)); + } else if (n->IsNextIteration()) { + TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit)); + } else { + TF_RETURN_IF_ERROR(HandleGeneric(n, should_revisit)); } - SetPred(n, Graph::kControlSlot, pred); return Status::OK(); } Status DeadnessAnalysisImpl::Populate() { std::vector rpo; - GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{}, + GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(), /*edge_filter=*/[](const Edge& edge) { return !edge.src()->IsNextIteration(); }); + return PopulateWithReversePostOrder(rpo); +} 
+Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( + absl::Span rpo) { // This an abstract interpretation over the deadness propagation semantics of // the graph executor. + // + // We iterate over the graph twice, each time in RPO. On the first iteration + // merge nodes with backedges are mapped to symbolic predicates. On the + // second iteration we use the predicates assigned to the backedges in the + // previous iteration to infer a more precise predicate for the backedge merge + // nodes and all the nodes that transitively use it. + // + // We don't track the output indices for should_revisit. Instead, putting a + // node in `should_revisit` denotes that the deadness flowing out from any + // output from said node may have changed. This is fine; only switches + // propagate different deadness along different output edges, and since the + // delta is solely due to the input *values* (and not input deadness), the + // delta should not change in the second iteration. + std::vector should_revisit; + should_revisit.resize(graph_.num_node_ids()); for (Node* n : rpo) { - if (n->IsSwitch()) { - TF_RETURN_IF_ERROR(HandleSwitch(n)); - } else if (n->IsMerge()) { - TF_RETURN_IF_ERROR(HandleMerge(n)); - } else if (n->IsControlTrigger()) { - SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue()); - } else if (n->IsRecv() || n->IsHostRecv()) { - TF_RETURN_IF_ERROR(HandleRecv(n)); - } else { - TF_RETURN_IF_ERROR(HandleGeneric(n)); + VLOG(4) << "Visiting " << n->name(); + TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr)); + if (n->IsNextIteration()) { + // If this is a backedge for a merge node then remember to reprocess the + // merge the next time we run. + for (const Edge* e : n->out_edges()) { + if (e->dst()->IsMerge()) { + should_revisit[e->dst()->id()] = true; + } + } + } + } + + for (Node* n : rpo) { + // The nodes added to should_revisit in the previous loop need to be + // revisited now. 
Reprocesing these initial nodes may add *their* consumers + // to should_revisit, and these newly added nodes will also be processed by + // this very same loop. Since we're traversing the graph in reverse post + // order (producers before consumers) and HandleNode(n) can only ever add + // n's consumers to should_revisit, we won't "miss" an addition to + // should_revisit. + if (should_revisit[n->id()]) { + VLOG(4) << "Revisiting " << n->name(); + TF_RETURN_IF_ERROR(HandleNode(n, &should_revisit)); } } @@ -589,6 +930,15 @@ Status ComputePredicates(const Graph& graph, *out_predicate_map = impl.PredicateMapAsString(); return Status::OK(); } + +Status ComputePredicates(const Graph& graph, + absl::Span reverse_post_order, + PredicateMapTy* out_predicate_map) { + DeadnessAnalysisImpl impl(&graph); + TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order)); + *out_predicate_map = impl.PredicateMapAsString(); + return Status::OK(); +} } // namespace deadness_analysis_internal } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index cdef4051108fdc5d063ab592676c7644989155bf..3df2679c629ce801fc6c9006415dcd27b40c078e 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -26,6 +26,14 @@ namespace deadness_analysis_internal { // testing purposes only. using PredicateMapTy = gtl::FlatMap; Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); + +// Returns a map describing the predicate each Tensor was mapped to. For +// testing purposes only. Makes deadness analysis visit the graph in the order +// specified in `reverse_post_order` which must be a valid RPO for the graph +// minus NextIteration->Merge edges. 
+Status ComputePredicates(const Graph& graph, + absl::Span reverse_post_order, + PredicateMapTy* out_predicate_map); } // namespace deadness_analysis_internal } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 6881095b51758d2e0b06c60021bc8c2860ac566e..28a56044d5e3795fc3ecf5d1092491b87cb90f01 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -32,12 +32,14 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { +using deadness_analysis_internal::ComputePredicates; +using deadness_analysis_internal::PredicateMapTy; + Status AnalyzeDeadness(Graph* graph, std::unique_ptr* result) { FixupSourceAndSinkEdges(graph); @@ -51,13 +53,73 @@ ops::Switch CreateSwitch(const Scope& root, const string& prefix) { return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate); } -Output CreateInductionVariable(const Scope& root, const string& prefix, - const string& frame_name, int32 init) { - Output initial_value = ops::Const(root.WithOpName(prefix + "/init"), init); +TensorId ControlOutputFor(const Output& o) { + return {o.node()->name(), Graph::kControlSlot}; +} + +void VLogGraphIfAsked(const Graph& graph) { + if (VLOG_IS_ON(3)) { + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + string serialized; + ::tensorflow::protobuf::TextFormat::PrintToString(graph_def, &serialized); + LOG(INFO) << serialized; + } +} + +struct InductionVarInfo { + Output induction_var; + Output loop_cond; +}; + +// Creates an induction variable with the following structure (simplified for +// brevity): +// +// +---------------+ +// | initial_value | +// 
+---------------+ +// | +// | +// v +// +---------------+ +// | Enter | +// +---------------+ +// | +// | +// v +// +---------------+ +// +> | Merge | -+ +// | +---------------+ | +// | | | +// | | | +// | v | +// | +---------------+ | +// | | LessThan10 | | +// | +---------------+ | +// | | | +// | | | +// | v | +// | +---------------+ | +// +----+- | Switch | <+ +// | | +---------------+ +// | | | +// | | | +// | | v +// | | +---------------+ +// | +- | AddOne | +// | +---------------+ +// | +---------------+ +// +-----> | Exit | +// +---------------+ +InductionVarInfo CreateInductionVariable(const Scope& root, + const string& prefix, + const string& frame_name, + const Output& initial_value) { Output enter_initial_value = ops::internal::Enter( root.WithOpName(prefix + "/enter"), initial_value, frame_name); - ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_initial_value}); + ops::Merge iv(root.WithOpName(prefix + "/iv"), + {enter_initial_value, enter_initial_value}); Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1); Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10); Output loop_cond_expr = @@ -66,16 +128,84 @@ Output CreateInductionVariable(const Scope& root, const string& prefix, ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr); ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); - Output iv_next = - ops::Add(root.WithOpName(prefix + "/ivnext"), iv.output, increment_by); + Output iv_next = ops::Add(root.WithOpName(prefix + "/ivnext"), + latch.output_true, increment_by); Output next_iteration = - ops::NextIteration(root.WithOpName(prefix + "next_iteration"), iv_next); + ops::NextIteration(root.WithOpName(prefix + "/next_iteration"), iv_next); - root.graph()->AddEdge(next_iteration.node(), 0, iv.output.node(), 1); + CHECK(root.graph() + ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1) + 
.ok()); root.graph()->AddControlEdge(iv.output.node(), increment_by.node()); root.graph()->AddControlEdge(iv.output.node(), final_value.node()); - return iv.output; + return {iv.output, loop_cond}; +} + +InductionVarInfo CreateInductionVariable(const Scope& root, + const string& prefix, + const string& frame_name, int32 init) { + return CreateInductionVariable( + root, prefix, frame_name, + ops::Const(root.WithOpName(prefix + "/init"), init)); +} + +// Creates an induction variable with the following structure: +// +// +---------------+ +// | initial_value | +// +---------------+ +// | +// | +// v +// +---------------+ +// | Enter | +// +---------------+ +// | +// | +// v +// +---------------+ +// | Merge | <+ +// +---------------+ | +// | | +// | | +// v | +// +-----------+ +---------------+ | +// | loop_cond | --> | Switch | -+ +// +-----------+ +---------------+ +// | +// | +// v +// +---------------+ +// | Exit | +// +---------------+ +struct DependentInductionVar { + Output induction_var; + ops::Switch latch; +}; + +DependentInductionVar CreateDependentLoopInvariantValue( + const Scope& root, const string& prefix, const string& frame_name, + const Output& loop_cond, const Output& value) { + Output enter_value = ops::internal::Enter(root.WithOpName(prefix + "/enter"), + value, frame_name); + ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_value, enter_value}); + ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond); + ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output); + Output next_iteration = ops::NextIteration( + root.WithOpName(prefix + "/next_iteration"), latch.output_true); + CHECK(root.graph() + ->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1) + .ok()); + return {iv.output, latch}; +} + +DependentInductionVar CreateDependentLoopInvariantValue( + const Scope& root, const string& prefix, const string& frame_name, + const Output& loop_cond, int32 value) { + return CreateDependentLoopInvariantValue( 
+ root, prefix, frame_name, loop_cond, + ops::Const(root.WithOpName(prefix + "/init"), value)); } TEST(DeadnessAnalysisTest, BasicPositive) { @@ -337,21 +467,224 @@ TEST(DeadnessAnalysisTest, HostRecv) { TEST(DeadnessAnalysisTest, Loop) { Scope root = Scope::NewRootScope().ExitOnError(); - Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0); - Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0); - Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1); + Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0).induction_var; + Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0).induction_var; + Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1).induction_var; Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1); Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2); - std::unique_ptr result; - TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - // NB! iv0 and iv1 are equivalent and a smarter deadness analysis would have // noticed that. Today we are pessimistic here because we assign an // uninterpreted symbol to merges with backedges. - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node())); + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node())); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + // In theory we should be able to tell that iv0/cond:0 and iv1/cond:0 + // produce the same deadness. But we're not that smart today. 
+ EXPECT_EQ(predicate_map[ControlOutputFor(iv0)], "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv1)], "{#true,&,*iv1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv2)], "{#true,&,*iv2/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "({#true,&,*iv1/cond:0} & {#true,&,*iv0/cond:0})"); + EXPECT_EQ(predicate_map[ControlOutputFor(add1)], + "({#true,&,*iv1/cond:0} & {#true,&,*iv2/cond:0})"); + } +} + +TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0); + Output dependent_iv0 = + CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0) + .induction_var; + Output dependent_iv1 = + CreateDependentLoopInvariantValue(root, "div1", "frame", iv.loop_cond, 0) + .induction_var; + Output add0 = ops::Add(root.WithOpName("add0"), dependent_iv0, dependent_iv1); + + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node())); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)], + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)], + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "{#true,&,(*iv0/cond:0 & iv0/iv:0)}"); + } +} + +TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { + // Create a merge that "looks like" a loop but isn't really. It has a value + // that does not depend on the merge on its backedge. 
+ Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv = CreateInductionVariable(root, "iv0", "frame", 0); + DependentInductionVar dependent_iv = + CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0); + FixupSourceAndSinkEdges(root.graph()); + + // To make deadness analysis think that dependent_iv is a loop we need an RPO + // that visits the merge before the backedge. This is a legal RPO for + // deadness analysis since it ignores NextIteration->Merge edges during RPO. + // Right now dependent_iv has an edge from Merge to NextIteration so do the + // RPO with this edge in place. Then remove this edge to get our test case. + std::vector rpo; + GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{}, + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + TF_ASSERT_OK(root.graph()->UpdateEdge( + iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0)); + + VLogGraphIfAsked(*root.graph()); + + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], + "div0/iv:0"); + } +} + +TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv_outer = + CreateInductionVariable(root, "iv_outer", "frame", 0); + ops::Switch inner_value(root.WithOpName("outer_is_live"), + ops::Const(root.WithOpName("constant"), 5), + iv_outer.loop_cond); + InductionVarInfo iv_inner = CreateInductionVariable( + root, "iv_inner", "frame", + ops::internal::Enter(root.WithOpName("iv_inner/enter"), + inner_value.output_true, "frame_inner")); + + Output dependent_outer_iv0 = + CreateDependentLoopInvariantValue(root, "dependent_outer_iv0", "frame", + iv_outer.loop_cond, 0) + .induction_var; + Output dependent_outer_iv1 = + CreateDependentLoopInvariantValue(root, "dependent_outer_iv1", "frame", + 
iv_outer.loop_cond, 0) + .induction_var; + + Output dependent_inner_iv0 = + CreateDependentLoopInvariantValue(root, "dependent_inner_iv0", "frame", + iv_inner.loop_cond, dependent_outer_iv0) + .induction_var; + Output dependent_inner_iv1 = + CreateDependentLoopInvariantValue(root, "dependent_inner_iv1", "frame", + iv_inner.loop_cond, dependent_outer_iv1) + .induction_var; + + Output add0 = ops::Add(root.WithOpName("add0"), dependent_inner_iv0, + dependent_inner_iv1); + + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add0.node())); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], + "{#true,&,*iv_outer/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)], + "{(*iv_outer/cond:0 & {#true,&,*iv_outer/cond:0}),&," + "*iv_inner/cond:0}"); + + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)], + "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," + "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], + "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," + "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "{{#true,&,(iv_outer/iv:0 & *iv_outer/cond:0)},&," + "(*iv_inner/cond:0 & iv_inner/iv:0)}"); + } +} + +TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv_outer_0 = + CreateInductionVariable(root, "iv_outer_0", "frame", 0); + ops::Switch inner_value_0(root.WithOpName("outer_0_is_live"), + ops::Const(root.WithOpName("constant"), 5), + iv_outer_0.loop_cond); + InductionVarInfo iv_inner_0 = CreateInductionVariable( + root, "iv_inner_0", "frame", + 
ops::internal::Enter(root.WithOpName("iv_inner_0/enter"), + inner_value_0.output_true, "frame_inner")); + + InductionVarInfo iv_outer_1 = + CreateInductionVariable(root, "iv_outer_1", "frame", 1); + ops::Switch inner_init_value_1(root.WithOpName("outer_1_is_live"), + ops::Const(root.WithOpName("constant"), 5), + iv_outer_1.loop_cond); + InductionVarInfo iv_inner_1 = CreateInductionVariable( + root, "iv_inner_1", "frame", + ops::internal::Enter(root.WithOpName("iv_inner_1/enter"), + inner_init_value_1.output_true, "frame_inner")); + Output add0 = ops::Add(root.WithOpName("add0"), iv_inner_0.induction_var, + iv_inner_1.induction_var); + + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node())); + } + + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_0.induction_var)], + "{#true,&,*iv_outer_0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_0.induction_var)], + "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," + "*iv_inner_0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer_1.induction_var)], + "{#true,&,*iv_outer_1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner_1.induction_var)], + "{(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," + "*iv_inner_1/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + "({(*iv_outer_1/cond:0 & {#true,&,*iv_outer_1/cond:0}),&," + "*iv_inner_1/cond:0} & " + "{(*iv_outer_0/cond:0 & {#true,&,*iv_outer_0/cond:0}),&," + "*iv_inner_0/cond:0})"); + } } TEST(DeadnessAnalysisTest, ControlInputs) { @@ -454,9 +787,8 @@ TEST(DeadnessAnalysisTest, RecvVsSwitchText) { std::unique_ptr result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - deadness_analysis_internal::PredicateMapTy predicate_map; - 
TF_ASSERT_OK(deadness_analysis_internal::ComputePredicates(*root.graph(), - &predicate_map)); + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); TensorId logical_and_output_0 = {logical_and.node()->name(), Graph::kControlSlot}; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index f150bf1819d407e1c6a279673a89de4307b5426b..2788102620546d8eab657c519f078c5b03e265cc 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/graph.h" @@ -44,7 +45,6 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -2504,7 +2504,8 @@ Status EncapsulateSubgraphsPass::Run( const int num_args = input_permutation->size(); std::vector const_args(num_args); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + **subgraph, &const_args, /*compile_time_const_nodes=*/nullptr)); DataTypeVector arg_types(num_args); TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index c0543a00792235c5dd090e81930d8c219dc7f1a3..7bc0ef030302dc6495e3e6d1151f458b450ed2c3 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" @@ -25,7 +26,6 @@ limitations under the License. 
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" @@ -124,8 +124,8 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, std::unordered_set control_input_a; std::unordered_set control_input_b; for (int i = 0; i < a.input_size(); ++i) { - if (str_util::StartsWith(a.input(i), "^")) { - if (!str_util::StartsWith(b.input(i), "^")) { + if (absl::StartsWith(a.input(i), "^")) { + if (!absl::StartsWith(b.input(i), "^")) { if (diff) { *diff = strings::StrCat( diff_preamble, " mismatch for node ", a.name(), " input ", i, @@ -379,7 +379,7 @@ Node* InputShaped(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTestShaped", opts); } -Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice& shape, +Node* KnownShapeBase(DataType dtype, absl::Span shape, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const", @@ -394,7 +394,7 @@ Node* KnownShapeBase(DataType dtype, const gtl::ArraySlice& shape, .FinalizeBuilder(&node_builder); } -Node* KnownShape(const gtl::ArraySlice& shape, +Node* KnownShape(absl::Span shape, const GraphDefBuilder::Options& opts) { return KnownShapeBase(DT_FLOAT, shape, opts); } @@ -417,8 +417,7 @@ Node* KeyPlaceholder(const string& call_node, } Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, - const string& oc_cluster, - const gtl::ArraySlice& dtypes, + const string& oc_cluster, absl::Span dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; string key = @@ -768,7 +767,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - 
str_util::StartsWith(n->name(), "const")) { + absl::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -813,7 +812,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - str_util::StartsWith(n->name(), "const")) { + absl::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -892,13 +891,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "c:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, @@ -1038,26 +1037,26 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", - gtl::ArraySlice({"outside_compilation_O1_host_compute"})}, + absl::Span({"outside_compilation_O1_host_compute"})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O2"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, {"F", "outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"Tinputs", 
gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1190,13 +1189,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1213,13 +1212,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"G:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}, + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); @@ -1364,13 +1363,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, - {{"Tinputs", 
gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1386,13 +1385,13 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"G:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F2_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"i_0_retval", "I:o:0"}}); @@ -1495,13 +1494,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {}, - {{"Tinputs", gtl::ArraySlice({})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}, + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1579,13 +1578,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {}, - {{"Tinputs", gtl::ArraySlice({})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - 
{"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", - gtl::ArraySlice({shape_proto_expected})}, + absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}}, {"D"}}, }, @@ -1661,12 +1660,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1742,12 +1741,12 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"f_0_retval", "F:o:0"}}); @@ -1846,13 +1845,13 @@ TEST(EncapsulateSubgraphsTest, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"F:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", 
"_outside_compilation_shape_inference_F1_O2"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}}, }, {{"h_0_retval", "H:o:0"}}); @@ -1955,13 +1954,13 @@ TEST(EncapsulateSubgraphsTest, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, }, {{"h_0_retval", "H:o:0"}}); @@ -2066,37 +2065,37 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, {"ancestors", - gtl::ArraySlice({"outside_compilation_O1_host_compute"})}, + absl::Span({"outside_compilation_O1_host_compute"})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}}, 
{"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O3_host_compute"}, "XlaHostCompute", {"D:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({})}, {"ancestors", - gtl::ArraySlice({"outside_compilation_O1_host_compute", - "outside_compilation_O2_host_compute"})}, + absl::Span({"outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"})}, {"key", "host_compute_channel_F1_O3"}, {"shape_inference_graph", ""}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O3"}}, {"outside_compilation_O1_host_compute", "outside_compilation_O2_host_compute"}}}, @@ -2272,13 +2271,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"c:o:0"}, - {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, - {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, - {"ancestors", gtl::ArraySlice({})}, + {{"Tinputs", absl::Span({DT_FLOAT})}, + {"Toutputs", absl::Span({DT_FLOAT})}, + {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, - {"shapes", gtl::ArraySlice({})}, + {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}}, {"c"}}, }, diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 8f78c110cb15f3cbc0344d102764241996b0d7de..253a5d254792a19d98b75310ea6848f42597c0c7 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -29,16 +29,3 @@ cc_library( ], alwayslink = 1, ) - -cc_library( - name = "parallel_check_op", - srcs = ["parallel_check_op.cc"], - visibility = ["//tensorflow/compiler/jit:friends"], - deps = [ - "//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - 
"//tensorflow/core:lib", - ], - alwayslink = 1, -) diff --git a/tensorflow/compiler/jit/kernels/parallel_check_op.cc b/tensorflow/compiler/jit/kernels/parallel_check_op.cc deleted file mode 100644 index bd4eefbc0bb960f8ddc1d238057e73a29a098f26..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/kernels/parallel_check_op.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace { - -// Inputs 2*N tensors, outputs the first N inputs. -// Logs errors if input tensor i and i + N are not (near) identical -// in any position. 
-class ParallelCheckOp : public OpKernel { - public: - explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - template - int CompareTensors(DataType dtype, const char* v0, const char* v1, - int64 num_elts, int input_idx) { - int failed = 0; - const T* p0 = reinterpret_cast(v0); - const T* p1 = reinterpret_cast(v1); - double rtol; - legacy_flags::ParallelCheckOpFlags* flags = - legacy_flags::GetParallelCheckOpFlags(); - if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(), - &rtol)) { - LOG(ERROR) << "can't convert parallel_check_rtol " - << flags->parallel_check_rtol << " to double"; - } - double atol; - if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(), - &atol)) { - LOG(ERROR) << "can't convert parallel_check_atol " - << flags->parallel_check_atol << " to double"; - } - for (int i = 0; i < num_elts; ++i) { - bool ok = (p0[i] == p1[i]); - VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i]; - if (!ok) { - if (std::is_same::value || std::is_same::value) { - float tolerance = - std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i]))); - T diff = p0[i] - p1[i]; - if (diff < 0) diff = 0 - diff; - ok = (diff <= tolerance); - } - if (ok) continue; - LOG(ERROR) << "Op " << name() << " fails equality at output " - << input_idx << " type " << DataTypeString(dtype) - << " element " << i << ": std_val=" << p0[i] - << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]); - if (++failed > 10) break; - } - } - return failed; - } - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "Compute " << name(); - const int num_pairs = ctx->num_inputs() / 2; - for (int i = 0; i < num_pairs; ++i) { - CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs)); - Tensor t0 = ctx->input(i); - Tensor t1 = ctx->input(i + num_pairs); - int64 num_elts = t0.NumElements(); - CHECK_EQ(num_elts, t1.NumElements()); - - // Compare inputs elementwise for near-exact equality. 
- const char* v0 = t0.tensor_data().data(); - const char* v1 = t1.tensor_data().data(); - int failed = 0; - switch (ctx->input_dtype(i)) { - case DT_INT32: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_INT64: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_FLOAT: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_DOUBLE: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_BOOL: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - default: - LOG(FATAL) << "unimpl: " << ctx->input_dtype(i); - } - if (failed > 0) { - LOG(ERROR) << "check failed for " << name() << " output " << i - << " num_elts: " << num_elts; - legacy_flags::ParallelCheckOpFlags* flags = - legacy_flags::GetParallelCheckOpFlags(); - if (flags->parallel_check_failfast) { - LOG(QFATAL) << "failfast on first parallel-check failure"; - } - } else { - VLOG(1) << "check passed for " << name() << " output " << i - << " num_elts: " << num_elts; - } - - // Propagate the std value. - if (IsRefType(ctx->input_dtype(i))) { - ctx->forward_ref_input_to_ref_output(i, i); - } else { - ctx->set_output(i, ctx->input(i)); - } - } - } - - TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp); -}; - -REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU), - ParallelCheckOp); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 7f4370b5b07b249bc9cf1f2ecf4086de359be68c..b6f2f632f7155234c87a0ea16fdc1910a09ed139 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -16,7 +16,6 @@ limitations under the License. 
#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" #include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -57,18 +56,17 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, ->stream->parent() ->platform() ->id(); - } else { - platform_id_ = nullptr; + } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) { + use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams(); + platform_id_ = xla_device_metadata_->platform()->id(); } } Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx, XlaCompilationCache** cache) { - const XlaDevice::Metadata* metadata; - Status s = XlaDevice::GetMetadata(ctx, &metadata); - if (s.ok()) { - *cache = new XlaCompilationCache(metadata->client(), - metadata->jit_device_type()); + if (xla_device_metadata_) { + *cache = new XlaCompilationCache(xla_device_metadata_->client(), + xla_device_metadata_->jit_device_type()); return Status::OK(); } @@ -117,18 +115,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { // this is more obviously correct.) core::ScopedUnref cache_ref(cache); - const XlaDevice::Metadata* metadata = nullptr; - Status s = XlaDevice::GetMetadata(ctx, &metadata); - bool allocate_xla_tensors = s.ok(); - bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams(); - - // Get the platform_id_ for XLA_* devices. - if (platform_id_ == nullptr) { - if (s.ok()) { - platform_id_ = metadata->platform()->id(); - } - } - std::map variables = SnapshotResourceVariables(ctx, resources_); @@ -146,7 +132,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { // (which local_xla_allocator above uses) as on an XlaDevice, this is a // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a // real allocator to allocate real buffers. 
- if (allocate_xla_tensors) { + if (xla_device_metadata_) { xla_allocator = client->backend().memory_allocator(); } else { xla_allocator = &local_xla_allocator; @@ -163,8 +149,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); options.device_allocator = xla_allocator; - if (metadata) { - options.shape_representation_fn = metadata->shape_representation_fn(); + if (xla_device_metadata_) { + options.shape_representation_fn = + xla_device_metadata_->shape_representation_fn(); } const XlaCompiler::CompilationResult* kernel; @@ -176,22 +163,25 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { } XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; - // Optimization: don't resolve constants. If we resolve constants we never - // emit them on the device, meaning that if they are needed by a following - // computation the host has to transfer them. - compile_options.resolve_compile_time_constants = false; + // If we resolve constants we never emit them on the device, meaning that if + // they are needed by a following computation the host has to transfer + // them. Not resolving constants is expected to be faster than resolving + // constants. + compile_options.resolve_compile_time_constants = true; // Optimization: where possible, have the computation return a naked array // rather than a one-element tuple. 
compile_options.always_return_tuple = false; OP_REQUIRES_OK( ctx, cache->Compile(options, function_, constant_args, variables, ctx, - &kernel, &executable, &compile_options)); + &kernel, &executable, compile_options)); VLOG(1) << "Executing XLA Computation..."; XlaComputationLaunchContext launch_context( - client, xla_allocator, allocate_xla_tensors, use_multiple_streams); + client, xla_allocator, + /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr, + use_multiple_streams_); launch_context.PopulateInputs(ctx, kernel, variables); // Execute the computation. diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h index 8dfc4b382d51151b6383fe7dd75429f3124d39be..e0f10e981737ad60e2b785a235dcb7fe7d21a053 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_ -#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_ +#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ +#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ #include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -58,7 +59,9 @@ class XlaLocalLaunchBase : public OpKernel { DeviceType device_type_; NameAttrList function_; - se::Platform::Id platform_id_; + se::Platform::Id platform_id_ = nullptr; + bool use_multiple_streams_ = false; + const XlaDevice::Metadata* xla_device_metadata_ = nullptr; }; // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph @@ -81,4 +84,4 @@ class XlaLocalLaunchOp : public XlaLocalLaunchBase { } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LOCAL_LAUNCH_OP_H_ +#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 90d5d56998c75d6e2c7a64e8516591153a26f82d..4e4abade3278089a1c7f8fdee46a34b8ce503651 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -27,7 +27,9 @@ limitations under the License. 
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" @@ -39,7 +41,10 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" namespace tensorflow { @@ -72,18 +77,40 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); } +bool HasResourceOutput(const Node& node) { + return std::find(node.output_types().begin(), node.output_types().end(), + DT_RESOURCE) != node.output_types().end(); +} + +bool HasResourceInput(const Node& node) { + return std::find(node.input_types().begin(), node.input_types().end(), + DT_RESOURCE) != node.input_types().end(); +} + +// Returns true if `node` is a resource operation recognized by tf2xla that +// operates on something other than resource variables. +bool IsNonResourceVarResourceOp(const Node& node) { + // TODO(b/112837194): We can't cluster these because we only support + // snapshotting resource variables (and we can't e.g. snapshot stacks). This + // limitation may be fixable with some work. 
+ const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(node.type_string()); + return op_info && op_info->resource_kind() != XlaResourceKind::kVariable; +} + // Make sure we don't recurse infinitely on recursive functions. const int kMaxRecursionDepth = 10; bool IsCompilableCall(const NodeDef& call_def, - const DeviceType& jit_device_type, int depth, + const DeviceType& jit_device_type, + bool allow_resource_ops, int depth, FunctionLibraryRuntime* lib_runtime); // Tests whether 'while_node' is a completely compilable loop. // Every operator in the condition and body functions must be compilable for a // while loop to be compilable. bool IsCompilableWhile(const Node& while_node, - const DeviceType& jit_device_type, int depth, + const DeviceType& jit_device_type, + bool allow_resource_ops, int depth, FunctionLibraryRuntime* lib_runtime) { const NameAttrList* name_attr; NodeDef call; @@ -98,7 +125,8 @@ bool IsCompilableWhile(const Node& while_node, call.set_name("while_cond"); call.set_op(cond_func); *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { + if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1, + lib_runtime)) { VLOG(2) << "Rejecting While " << while_node.name() << ": can't compile loop condition: " << cond_func; return false; @@ -113,7 +141,8 @@ bool IsCompilableWhile(const Node& while_node, call.set_name("while_body"); call.set_op(body_func); *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { + if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1, + lib_runtime)) { VLOG(2) << "Rejecting While " << while_node.name() << ": can't compile loop body: " << body_func; return false; @@ -125,7 +154,8 @@ bool IsCompilableWhile(const Node& while_node, // Every operator in the function must be compilable for a function to be // compilable. 
bool IsCompilableCall(const NodeDef& call_def, - const DeviceType& jit_device_type, int depth, + const DeviceType& jit_device_type, + bool allow_resource_ops, int depth, FunctionLibraryRuntime* lib_runtime) { if (depth > kMaxRecursionDepth) { VLOG(2) << "Rejecting " << call_def.op() @@ -141,6 +171,10 @@ bool IsCompilableCall(const NodeDef& call_def, << ": could not instantiate: " << status; return false; } + + auto release_handle_on_return = gtl::MakeCleanup( + [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); }); + const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); CHECK(fbody); const FunctionDef& fdef = fbody->fdef; @@ -161,12 +195,17 @@ bool IsCompilableCall(const NodeDef& call_def, if (node->type_string() == "_Arg" || node->type_string() == "_Retval") continue; if (node->type_string() == "While") { - // Handle functional While loop (not in open source build). - return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime); + // Handle functional While loop. 
+ return IsCompilableWhile(*node, jit_device_type, allow_resource_ops, + depth + 1, lib_runtime); + } + if (!allow_resource_ops && + (HasResourceInput(*node) || HasResourceOutput(*node))) { + return false; } if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, depth + 1, - lib_runtime)) { + !IsCompilableCall(node->def(), jit_device_type, allow_resource_ops, + depth + 1, lib_runtime)) { VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op " << node->name() << ": " << node->def().ShortDebugString(); return false; @@ -337,6 +376,10 @@ Status FindCompilationCandidates( flib_def, opts)); FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + std::vector compile_time_const_nodes(graph.num_node_ids(), false); + TF_RETURN_IF_ERROR( + BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr, + &compile_time_const_nodes)); int64& fuel = legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; @@ -380,19 +423,46 @@ Status FindCompilationCandidates( XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)); DeviceType jit_device_type(registration->compilation_device_name); if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) { + !IsCompilableCall(node->def(), jit_device_type, + registration->compile_resource_ops, 0, lib_runtime)) { VLOG(2) << "Rejecting " << node->name() << ": unsupported op " << node->type_string(); continue; } if (!registration->compile_resource_ops && - HasResourceInputOrOutput(*node)) { - VLOG(2) << "Rejecting: " << node->name() << ": resource input/output " + (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { + // We don't have a way of returning values of type DT_RESOURCE from XLA + // computations so we avoid auto-clustering nodes producing DT_RESOURCE. 
+ // XlaLaunchOp also cannot snapshot resources that are not resource + // variables so we avoid clustering resource operations that operate on + // non-resource variables. + VLOG(2) << "Rejecting: " << node->name() << ": resource output " << node->type_string(); continue; } + if (compile_time_const_nodes[node->id()] && + !registration->requires_compilation) { + const OpDef* op_def; + TF_RETURN_IF_ERROR( + OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def)); + if (op_def->is_stateful()) { + // We need to be able to constant fold the nodes in + // compile_time_const_nodes given constant inputs (required by XLA) and + // therefore can't auto-cluster stateful ops since these can never be + // constant folded. + VLOG(2) << "Rejecting " << node->name() + << ": must-be-constant stateful op"; + continue; + } + } + // We don't auto-cluster functional control flow nodes containing resource + // operations because safety checks are trickier in this case. + // registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not + // for CPU/GPU. if (node->type_string() == "While" && - !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) { + !IsCompilableWhile(*node, jit_device_type, + registration->compile_resource_ops, 0, + lib_runtime)) { continue; } // _Arg nodes in a top-level function represent feeds. @@ -412,6 +482,31 @@ Status FindCompilationCandidates( return Status::OK(); } +// Determine the global jit level which is ON if either the +// GraphOptimizationPassOptions has the jit ON, or if the --tf_xla_auto_jit flag +// is true. +OptimizerOptions::GlobalJitLevel GetGlobalJitLevel( + const GraphOptimizationPassOptions& options) { + OptimizerOptions::GlobalJitLevel global_jit_level = + options.session_options->config.graph_options() + .optimizer_options() + .global_jit_level(); + if (global_jit_level == OptimizerOptions::DEFAULT) { + // To set compilation to be on by default, change the following line. 
+ global_jit_level = OptimizerOptions::OFF; + } + legacy_flags::MarkForCompilationPassFlags* flags = + legacy_flags::GetMarkForCompilationPassFlags(); + if (flags->tf_xla_auto_jit == -1 || + (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) { + // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides + // the setting in ConfigProto. + global_jit_level = + static_cast(flags->tf_xla_auto_jit); + } + return global_jit_level; +} + struct Cluster { // Identifies the node that represents this cluster in the cycle detection // graph. @@ -426,7 +521,11 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), ®istration)); DeviceType jit_device_type(registration->compilation_device_name); - return IsCompilableCall(ndef, jit_device_type, 0, flr); + + // We can always *compile* resource operations, even if we are sometimes + // unable to auto-cluster them. + const bool compile_resource_ops = true; + return IsCompilableCall(ndef, jit_device_type, compile_resource_ops, 0, flr); } Status MarkForCompilationPass::Run( @@ -434,22 +533,9 @@ Status MarkForCompilationPass::Run( // TODO(phawkins): precompute the "GetCompilationDevice" properties of each // device ahead of time. OptimizerOptions::GlobalJitLevel global_jit_level = - options.session_options->config.graph_options() - .optimizer_options() - .global_jit_level(); - if (global_jit_level == OptimizerOptions::DEFAULT) { - // To set compilation to be on by default, change the following line. - global_jit_level = OptimizerOptions::OFF; - } + GetGlobalJitLevel(options); legacy_flags::MarkForCompilationPassFlags* flags = legacy_flags::GetMarkForCompilationPassFlags(); - if (flags->tf_xla_auto_jit == -1 || - (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) { - // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides - // the setting in ConfigProto. 
- global_jit_level = - static_cast(flags->tf_xla_auto_jit); - } bool cpu_global_jit = flags->tf_xla_cpu_global_jit; bool fusion_only = flags->tf_xla_fusion_only; @@ -517,9 +603,9 @@ Status MarkForCompilationPass::Run( bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; bool should_compile = (ignore_registration || registration->enable_jit_by_default) && - global_jit_level > 0; + global_jit_level != OptimizerOptions::OFF; if (!should_compile) { - if (global_jit_level <= 0) { + if (global_jit_level == OptimizerOptions::OFF) { VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; } else { VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled."; @@ -530,6 +616,136 @@ Status MarkForCompilationPass::Run( return RunImpl(options, is_compilable); } +static string RatioToString(int numerator, int denominator) { + return strings::Printf("%d / %d (%.2f%%)", numerator, denominator, + (100.0 * numerator) / denominator); +} + +static void VLogClusteringSummary(const Graph& g) { + if (!VLOG_IS_ON(2)) { + return; + } + + std::map cluster_name_to_size; + std::map> + cluster_name_to_op_histogram; + std::map unclustered_op_histogram; + int clustered_node_count = 0; + + for (Node* n : g.nodes()) { + absl::optional cluster_name = GetXlaClusterForNode(*n); + if (cluster_name) { + clustered_node_count++; + cluster_name_to_size[*cluster_name]++; + cluster_name_to_op_histogram[*cluster_name][n->type_string()]++; + } else { + unclustered_op_histogram[n->type_string()]++; + } + } + + int unclustered_node_count = g.num_nodes() - clustered_node_count; + + VLOG(2) << "*** Clustering info for graph of size " << g.num_nodes(); + VLOG(2) << " Built " << cluster_name_to_size.size() << " clusters, size " + << RatioToString(clustered_node_count, g.num_nodes()); + + for (const auto& cluster_name_size_pair : cluster_name_to_size) { + StringPiece cluster_name = cluster_name_size_pair.first; + int size = cluster_name_size_pair.second; + VLOG(2) << " " << 
cluster_name << " " + << RatioToString(size, g.num_nodes()); + for (const auto& op_count_pair : + cluster_name_to_op_histogram[cluster_name]) { + VLOG(3) << " " << op_count_pair.first << ": " << op_count_pair.second + << " instances"; + } + } + + if (!unclustered_op_histogram.empty()) { + VLOG(2) << " Unclustered nodes: " + << RatioToString(unclustered_node_count, g.num_nodes()); + for (const auto& pair : unclustered_op_histogram) { + VLOG(3) << " " << pair.first << ": " << pair.second << " instances"; + } + } + + struct EdgeInfo { + StringPiece node_name; + absl::optional cluster_name; + + StringPiece GetClusterName() const { + return cluster_name ? *cluster_name : "[none]"; + } + + std::pair> AsPair() const { + return {node_name, cluster_name}; + } + + bool operator<(const EdgeInfo& other) const { + return AsPair() < other.AsPair(); + } + }; + + using EdgeInfoMap = std::map>; + + EdgeInfoMap incoming_edge_infos; + EdgeInfoMap outgoing_edge_infos; + + std::set cluster_names_to_print; + + for (const Edge* e : g.edges()) { + const Node* from = e->src(); + absl::optional from_cluster_name = GetXlaClusterForNode(*from); + + const Node* to = e->dst(); + absl::optional to_cluster_name = GetXlaClusterForNode(*to); + + if (to_cluster_name == from_cluster_name) { + continue; + } + + if (to_cluster_name) { + incoming_edge_infos[*to_cluster_name] + [EdgeInfo{from->name(), from_cluster_name}]++; + cluster_names_to_print.insert(*to_cluster_name); + } + + if (from_cluster_name) { + outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++; + cluster_names_to_print.insert(*from_cluster_name); + } + } + + VLOG(2) << "*** Inter-Cluster edges:"; + if (cluster_names_to_print.empty()) { + VLOG(2) << " [none]"; + } + + auto print_edge_info_set_for_cluster = [&](StringPiece cluster_name, + const EdgeInfoMap& edge_info_map, + StringPiece desc) { + auto it = edge_info_map.find(cluster_name); + if (it != edge_info_map.end()) { + VLOG(2) << " " << it->second.size() << " " << 
desc << " edges"; + for (const auto& edge_info_count_pair : it->second) { + VLOG(2) << " " << edge_info_count_pair.first.GetClusterName() << " " + << edge_info_count_pair.first.node_name << " # " + << edge_info_count_pair.second; + } + } else { + VLOG(2) << " No " << desc << " edges."; + } + }; + + for (StringPiece cluster_name : cluster_names_to_print) { + VLOG(2) << " ** Cluster " << cluster_name; + print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos, + "incoming"); + print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos, + "outgoing"); + } +} + // Is 'node' an operator that consumes only the shape of its input, not the // data itself? static bool IsShapeConsumerOp(const Node& node) { @@ -537,6 +753,43 @@ static bool IsShapeConsumerOp(const Node& node) { node.type_string() == "Size"; } +static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) { + // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then + // ignore it during resource operation safety analysis. We need this hack + // because of two reasons: + // + // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled. + // 2. We don't support live-out values of type DT_RESOURCE and live-in values + // of type DT_RESOURCE that are not resource variables. + // + // Together these imply we cannot let resource variable safety analysis + // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different + // clusters: both of them will have to be clustered because of (1) and we + // won't be able to keep the edge between the two as neither the input to the + // second XLA cluster nor the output from the first XLA cluster are supported + // because of (2). + // + // TODO(b/113100872): This can be fixed if the TensorFlow representation for + // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then + // (2) would no longer hold. 
+ + if (n.assigned_device_name().empty()) { + *ignore = false; + return Status::OK(); + } + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n.assigned_device_name(), &device_type)); + + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + *ignore = true; + } else { + *ignore = registration->compile_resource_ops; + } + return Status::OK(); +} + // Sequence number generator to ensure clusters have unique names. static std::atomic cluster_sequence_num; @@ -565,6 +818,8 @@ Status MarkForCompilationPass::RunImpl( GraphCycles cycles; TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); + TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( + graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles)); // Each compilation candidate belongs to a cluster. The cluster's // representative @@ -577,6 +832,8 @@ Status MarkForCompilationPass::RunImpl( worklist.push_back(&clusters[node->id()]); } + OptimizerOptions::GlobalJitLevel global_jit_level = + GetGlobalJitLevel(options); legacy_flags::MarkForCompilationPassFlags* flags = legacy_flags::GetMarkForCompilationPassFlags(); @@ -601,7 +858,7 @@ Status MarkForCompilationPass::RunImpl( string to_scope; for (int to : cycles.Successors(from)) { if (to >= graph->num_node_ids()) { - // Node is a "frame" node that is present only in the cycle detection + // Node is a fictitious node that is present only in the cycle detection // graph. No clustering is possible. continue; } @@ -616,13 +873,15 @@ Status MarkForCompilationPass::RunImpl( } // Look for an _XlaScope on both nodes. If both nodes have a // scope and the scopes do not match, do not cluster along this - // edge. If even one of the nodes lacks an _XlaScope attribute, + // edge. This restriction is overridden if the global_jit_level is ON. 
If + // even one of the nodes lacks an _XlaScope attribute, // then it is treated as a "bridge" and a cluster may be created // along it. We may want to restrict this behavior to require // all nodes marked with _XlaCompile=true to also have a // _XlaScope property set (and raise an error otherwise); but // for now we don't do this. - if (GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() && + if (global_jit_level == OptimizerOptions::OFF && + GetNodeAttr(node_from->attrs(), kXlaScopeAttr, &from_scope).ok() && GetNodeAttr(node_to->attrs(), kXlaScopeAttr, &to_scope).ok() && from_scope != to_scope) { continue; @@ -718,6 +977,9 @@ Status MarkForCompilationPass::RunImpl( dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph, options.flib_def); } + + VLogClusteringSummary(*graph); + return Status::OK(); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index a780d4a936a3b757495c26d337f19c80a67f343a..807ab51fd3c133b95915ea88e0bf99dbb8661452 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,10 +15,12 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" @@ -26,11 +28,11 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -48,9 +50,35 @@ std::unordered_map GetClusters(const Graph& graph) { ids[node->name()] = cluster; } } + + if (VLOG_IS_ON(2)) { + VLOG(2) << "Clusters:"; + for (const auto& p : ids) { + VLOG(2) << " " << p.first << " -> " << p.second; + } + } return ids; } +gtl::FlatMap> GetClusterSets( + const Graph& g, std::vector* cluster_names = nullptr) { + CHECK(cluster_names == nullptr || cluster_names->empty()); + gtl::FlatMap> cluster_sets; + for (const auto& p : GetClusters(g)) { + cluster_sets[p.second].push_back(p.first); + } + for (auto& p : cluster_sets) { + if (cluster_names != nullptr) { + cluster_names->push_back(p.first); + } + std::sort(p.second.begin(), p.second.end()); + } + if (cluster_names != nullptr) { + std::sort(cluster_names->begin(), cluster_names->end()); + } + return cluster_sets; +} + TEST(XlaCompilationTest, Chains) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; @@ -199,7 +227,7 @@ TEST(XlaCompilationTest, FunctionCalls) { {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}}); FunctionDef noinline = compilable; noinline.mutable_signature()->set_name("NoInlineFn"); - AddAttr("_noinline", bool(true), noinline.mutable_attr()); + AddAttr("_noinline", static_cast(true), noinline.mutable_attr()); FunctionDefLibrary flib; *flib.add_function() = compilable; @@ -372,6 +400,44 @@ TEST(XlaCompilationTest, Loops) { EXPECT_EQ(0, clusters.size()); } +TEST(XlaCompilationTest, 
CyclesWithAllDifferentScopesGlobalJitOverridden) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor()) + .WithAttr(kXlaScopeAttr, "ScopeA")); + Node* b = ops::UnaryOp( + "Relu", a, + builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB")); + ops::BinaryOp( + "MatMul", a, b, + builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC")); + TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + FunctionDefLibrary flib; + FunctionLibraryDefinition flib_def(graph->op_registry(), flib); + SessionOptions session_options; + session_options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_global_jit_level(OptimizerOptions::ON_2); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation( + &graph, &flib_def, &session_options)); + auto clusters = GetClusters(*graph); + + // The computation is: C = A + relu(A) + // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC. + // In this case, the GlobalJitLevel overrides the scopes to cluster while + // ignoring scopes. 
+ EXPECT_EQ(3, clusters.size()); + EXPECT_EQ(clusters["A"], clusters["B"]); + EXPECT_EQ(clusters["A"], clusters["C"]); +} + TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; @@ -463,38 +529,104 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { EXPECT_EQ(clusters["B"], clusters["C"]); } -REGISTER_OP("ResourceInput").Input("a: resource").Output("o: float"); -REGISTER_OP("ResourceOutput").Input("a: float").Output("o: resource"); - namespace { +Node* MakeRead(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output read = + ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT); + return read.node(); +} -class DummyOp : public XlaOpKernel { - using XlaOpKernel::XlaOpKernel; - void Compile(XlaOpKernelContext* ctx) override {} -}; - -REGISTER_XLA_OP(Name("ResourceInput"), DummyOp); -REGISTER_XLA_OP(Name("ResourceOutput"), DummyOp); +Node* MakeWrite(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = + ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id), + var_handle, value_to_write); + return assign_op.operation.node(); +} +Node* MakeNeutral(const Scope& scope, const string& id) { + return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); +} } // namespace -TEST(XlaCompilationTest, Resources) { +TEST(XlaCompilationTest, ResourcesClusteringAllowed) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(read, write); + + FixupSourceAndSinkEdges(root.graph()); std::unique_ptr graph(new Graph(OpRegistry::Global())); - GraphDef graphdef; - { - GraphDefBuilder 
builder(GraphDefBuilder::kFailImmediately); - Node* a = - ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); - Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); - // We should not form clusters with resource ops by default. - Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C")); - Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D")); - ops::UnaryOp("Relu", d, builder.opts().WithName("E")); - TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); - } + TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - auto clusters = GetClusters(*graph); - EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. + gtl::FlatMap> cluster_sets = + GetClusterSets(*graph); + ASSERT_EQ(cluster_sets.size(), 1); + std::vector expected_clustered_nodes = {"AssignmentW", "ReadR", + "ValueToAssignW"}; + ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); +} + +TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, read); + + FixupSourceAndSinkEdges(root.graph()); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_EXPECT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + gtl::FlatMap> cluster_sets = + GetClusterSets(*graph); + ASSERT_EQ(cluster_sets.size(), 1); + std::vector expected_clustered_nodes = {"AssignmentW", + "ValueToAssignW"}; + ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); +} + +TEST(XlaCompilationTest, ChainOfOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* neutral_0 = MakeNeutral(root, "N0"); + Node* read_0 = MakeRead(root, "R0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral_1 = 
MakeNeutral(root, "N1"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral_0); + root.graph()->AddControlEdge(neutral_0, read_0); + root.graph()->AddControlEdge(read_0, write_1); + root.graph()->AddControlEdge(write_1, neutral_1); + root.graph()->AddControlEdge(neutral_1, read_1); + + FixupSourceAndSinkEdges(root.graph()); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_EXPECT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::vector cluster_names; + gtl::FlatMap> cluster_sets = + GetClusterSets(*graph, &cluster_names); + + ASSERT_EQ(cluster_sets.size(), 2); + + std::vector expected_clustered_nodes_a = {"AssignmentW0", "ConstN0", + "ValueToAssignW0"}; + ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a); + + std::vector expected_clustered_nodes_b = { + "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"}; + ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b); } TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { @@ -524,11 +656,11 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.ToString(), - "Edge from c to a would create a cycle.\n" - "+-> a\n" - "| b\n" - "+-- c\n")); + EXPECT_TRUE(absl::StrContains(status.ToString(), + "Edge from c to a would create a cycle.\n" + "+-> a\n" + "| b\n" + "+-- c\n")); } TEST(XlaCompilationTest, Retval) { @@ -693,5 +825,27 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { EXPECT_EQ(clusters, expected_clusters); } +TEST(XlaCompilationTest, RandomShape) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1}); + Output shape = + ops::RandomUniformInt(root.WithOpName("shape"), shape_shape, + ops::Const(root.WithOpName("minval"), 1), + 
ops::Const(root.WithOpName("maxval"), 20)); + Output reshape_input = + ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("reshape"), reshape_input, shape); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["shape"], ""); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc index a84b82e47923b2e7eec0e7eb848bd4377befbd07..65669877f732bad9e145da36a3aedeba611a0fe5 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -14,10 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "tensorflow/core/public/session_options.h" namespace tensorflow { /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( - std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + SessionOptions* session_options) { // Assign all nodes to the CPU device. 
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; for (Node* n : (*graph)->nodes()) { @@ -26,11 +28,18 @@ namespace tensorflow { GraphOptimizationPassOptions opt_options; opt_options.graph = graph; + opt_options.session_options = session_options; opt_options.flib_def = flib_def; MarkForCompilationPass pass; return pass.RunImpl(opt_options); } +/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + SessionOptions session_options; + return MarkForCompilation(graph, flib_def, &session_options); +} + /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( std::unique_ptr* graph) { FunctionDefLibrary flib; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h index b9a0531cb0e431a98d57a6d9a2e3e41b51e7b743..216baaf933dc1f7e694289eea5d23996b595f4d4 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h @@ -24,6 +24,11 @@ class MarkForCompilationPassTestHelper { // Runs the MarkForCompilation pass on `graph` after assigning all nodes in // `graph` to the CPU device. To make testing easier, ignores device // registration, _XlaCompile attributes, input deadness and global jit level. + static Status MarkForCompilation(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def, + SessionOptions* session_options); + + // Like `MarkForCompilation` but creates a default SessionOptions. 
static Status MarkForCompilation(std::unique_ptr* graph, FunctionLibraryDefinition* flib_def); diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index c9e46bc1475aed0e35a48765ad70eef4362e8281..13804c6a0575b921839f99ef7d142e0871693b5a 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -10,10 +10,3 @@ cc_library( deps = ["//tensorflow/core:framework"], alwayslink = 1, ) - -cc_library( - name = "parallel_check_op", - srcs = ["parallel_check_op.cc"], - deps = ["//tensorflow/core:framework"], - alwayslink = 1, -) diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 68ead39424c35c1ef0bcc92e57af7931c0c57462..a8f09bfa5034e020fe3448d8ecfe0f70605e14d2 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { namespace { Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, - gtl::ArraySlice post_order) { + absl::Span post_order) { // Find nodes that have at least one user outside their cluster that expects // hostmem output. These nodes should be cloned to outside the cluster to // avoid the device-host copy we'd otherwise need. @@ -30,7 +30,7 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, MemoryTypeVector input_mtypes, output_mtypes; for (Node* n : post_order) { - gtl::optional from_cluster = GetXlaClusterForNode(*n); + absl::optional from_cluster = GetXlaClusterForNode(*n); if (!from_cluster) { continue; } @@ -79,8 +79,8 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, // Check if `dst` is in a different cluster, unclustered, or about to be // partially declustered (here we rely on the post-order traversal order). // If yes, decluster `n` to avoid the device-to-host memcpy. - gtl::optional dst_cluster = - result->count(dst) ? 
gtl::nullopt : GetXlaClusterForNode(*dst); + absl::optional dst_cluster = + result->count(dst) ? absl::nullopt : GetXlaClusterForNode(*dst); if (from_cluster != dst_cluster) { CHECK(result->insert(n).second); break; @@ -99,7 +99,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { } Node* dst = out_edge->dst(); - gtl::optional dst_cluster_name = GetXlaClusterForNode(*dst); + absl::optional dst_cluster_name = GetXlaClusterForNode(*dst); if (dst_cluster_name != cluster_name) { out_edges_to_clone.push_back(out_edge); } diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 08a956e4c6478ff76a0fe8f1f60d94824daf535c..f61a955c222dd7ce11a177cd54bb8851a5400496 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..1ba4a5ef7399111e512da8c4966f5899ed828b17 --- /dev/null +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -0,0 +1,336 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// ALGORITHM OVERVIEW +// ================== +// +// An XLA cluster hoists all resource reads to the beginning of the cluster +// execution and all the resource writes to the end. This means it cannot +// enforce arbitrary ordering dependencies (via control or data edges) between +// resource operations. Since all resource reads happen before all resource +// writes, edges constraining resource reads to happen before resource writes +// are fine, but all other kinds of edges are problematic. This analysis +// computes the set of pairs of resource operations that cannot be put in the +// same cluster because XLA cannot respect the dependencies between them in the +// TensorFlow program. +// +// TODO(b/112856632): We can, in theory, support Read->Read and Write->Write +// dependencies. +// +// Specifically the result computed by this analysis contains the edge {W, R} +// iff all of these hold true: +// +// - In the graph (g - {edges from NextIteration to Merge}) there is a path +// from W to R. +// - IsEdgeSafe(W, R) == False [defined below] +// - W != R (note: some resource operations both read from and write to +// resource variables). +// +// The result is incorrect around loops because we ignore edges from +// NextIteration to Merge, but that should be fine because we don't cluster +// these edges.
For instance, in: +// +// Init -----> Merge <-------+ +// | | +// v | +// Read | +// | | +// v | +// Write | +// | | +// v | +// NextIteration --+ +// +// we won't put (Read, Write) in the returned set. This is fine if +// auto-clustering can only cluster the Read->Write edge, but it is a problem if +// it clusters the Write->NextIteration->Merge->Read edges instead. The same +// problem is present for the functional version of the loop above. We rely on +// auto-clustering to not cluster control flow edges like NextIteration->Merge. +// This is enough to avoid the explicit-control-flow problem shown above. One +// way to think about this is that we only care about cases where two nodes, A +// and B, would normally have been put in the same cluster but cannot legally be +// in the same cluster because of resourcevar-dependencies. If A and B would +// normally have been put in the same cluster then all paths between A and B +// would have to be clusterable (otherwise we'd have introduced a cycle). Ergo +// there could not have been a NextIteration->Merge edge between A and B since +// we don't cluster these edges. +// +// We also rely on auto-clustering to not cluster functional control flow nodes +// that contain resource operations. +// +// IMPLEMENTATION +// -------------- +// +// We traverse the graph minus backedges in reverse post order, mapping each +// node to the set of resource operations reaching that node. Since we visit +// producers before consumers, we can construct the set of reaching operations +// by taking the union of the operations reaching the input nodes. These +// "reaching resource operations" can then be used to create the pairs of +// incompatible nodes using `IsEdgeSafe`.
+ +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace { +// Returns true if `n` may call a function. +Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def, + bool* out_result) { + if (flib_def->Contains(n.type_string())) { + *out_result = true; + } else { + *out_result = + std::any_of(n.def().attr().begin(), n.def().attr().end(), + [](const std::pair& name_attr_pair) { + return name_attr_pair.second.has_func(); + }); + } + + return Status::OK(); +} + +// Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is +// not a resource operation recognized by XLA then sets `out_resource_op_kind` +// to nullopt. +Status XlaResourceOpKindForNode( + const Node& n, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + absl::optional* out_resource_op_kind) { + bool should_ignore = false; + if (resource_ops_to_ignore) { + TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore)); + } + if (should_ignore) { + *out_resource_op_kind = absl::nullopt; + return Status::OK(); + } + + const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string()); + if (op_info) { + *out_resource_op_kind = op_info->kind(); + return Status::OK(); + } + + // We conservatively assume that functions will both read and write resource + // variables. In the future we may consider doing some form of + // inter-procedural analysis. 
+ bool may_call_function; + TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function)); + if (may_call_function) { + *out_resource_op_kind = XlaResourceOpKind::kReadWrite; + } else { + *out_resource_op_kind = absl::nullopt; + } + + return Status::OK(); +} + +// Returns true if a control or data dependence from a TensorFlow operation of +// resource op kind `from` to a TensorFlow operation of resource op kind `to` +// can be represented by an XLA cluster and needs no special handling around +// auto-jit. +bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { + // XLA clusters forces all reads to happen before all writes, which means the + // kinds of edges it can faithfully represent are: Read->Write, Read->Modify, + // Modify->Write, Read->Read, Write->Write. + // + // TODO(b/112856632): We can, in theory, support Read->Read and Write->Write + // dependencies. + return from == XlaResourceOpKind::kRead && to == XlaResourceOpKind::kWrite; +} + +using ResourceOp = std::pair; + +string ResourceOpToString(const ResourceOp& resource_op) { + return strings::StrCat( + resource_op.first, ": ", + XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second)); +} + +// A copy-on-write set used to store the set of ResourceOps reaching a node in a +// TensorFlow graph. +// +// TODO(sanjoy): It may be useful to pull this out into its own header at some +// point. +class ResourceOpSet { + private: + using Impl = gtl::FlatSet; + + public: + ResourceOpSet() = default; + + // Adds all ResourceOp s in `other` to this set. 
+ void Add(const ResourceOpSet& other) { + CHECK(!frozen_); + if (other.impl_ == impl_) { + other.frozen_ = true; + return; + } + + if (!impl_) { + other.frozen_ = true; + impl_ = other.impl_; + return; + } + + for (ResourceOp resource_op : other) { + Add(resource_op); + } + } + + void Add(const ResourceOp& resource_op) { + CHECK(!frozen_); + if (!IsCopy() && Contains(resource_op)) { + // We can avoid the copy if the item we want to insert already exists. + return; + } + + EnsureIsCopied(); + impl_->insert(resource_op); + } + + Impl::const_iterator begin() const { + return impl_ ? impl_->begin() : GetEmptyImpl()->begin(); + } + + Impl::const_iterator end() const { + return impl_ ? impl_->end() : GetEmptyImpl()->end(); + } + + bool Contains(const ResourceOp& resource_op) const { + return impl_ != nullptr && impl_->count(resource_op); + } + + private: + bool IsCopy() const { return storage_ != nullptr; } + + void EnsureIsCopied() { + if (storage_ == nullptr) { + storage_ = absl::make_unique(); + for (ResourceOp op : *this) { + storage_->insert(op); + } + impl_ = storage_.get(); + } + } + + static Impl* GetEmptyImpl() { + static Impl* empty_impl = new Impl; + return empty_impl; + } + + Impl* impl_ = nullptr; + std::unique_ptr storage_; + + // frozen_ is true if there is another set pointing to this set's impl_. We + // can no longer add elements to this set in that case since the sets pointing + // to this set expect the contents of this set to be stable. 
+ mutable bool frozen_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(ResourceOpSet); +}; + +string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { + std::vector elements_debug_string; + std::transform(resource_op_set.begin(), resource_op_set.end(), + std::back_inserter(elements_debug_string), ResourceOpToString); + return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); +} + +string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { + return strings::StrCat( + "[", n.name(), ": ", n.type_string(), "(", + XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]"); +} +} // namespace + +Status ComputeIncompatibleResourceOperationPairs( + const Graph& g, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + std::vector>* result) { + CHECK(result->empty()); + + std::vector rpo; + GetReversePostOrder(g, &rpo, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + + auto resource_op_set_for_node = + absl::make_unique(g.num_node_ids()); + + const bool vlog = VLOG_IS_ON(2); + + for (Node* n : rpo) { + absl::optional op_kind; + TF_RETURN_IF_ERROR(XlaResourceOpKindForNode( + *n, flib_def, resource_ops_to_ignore, &op_kind)); + + ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()]; + + // Merge the reaching resource operations for all the incoming edges to + // create the set of all possible resource ops reaching `n`. + for (const Edge* e : n->in_edges()) { + if (n->IsMerge() && e->src()->IsNextIteration()) { + // Ignore back-edges (see file comment). + continue; + } + + const ResourceOpSet& incoming_op_set = + resource_op_set_for_node[e->src()->id()]; + resource_op_set->Add(incoming_op_set); + } + + // Add to the "incompatible resource ops" set if necessary. 
+ if (op_kind) { + for (ResourceOp incoming_op : *resource_op_set) { + if (IsEdgeSafe(incoming_op.second, *op_kind)) { + continue; + } + + if (vlog) { + VLOG(2) << "Unsafe edge: " + << NodeToString(*g.FindNodeId(incoming_op.first), + incoming_op.second) + << " -> " << NodeToString(*n, *op_kind); + } + result->push_back({incoming_op.first, n->id()}); + } + + resource_op_set->Add({n->id(), *op_kind}); + } + + if (vlog) { + VLOG(3) << n->name() << " -> " << ResourceOpSetToString(*resource_op_set); + } + } + + std::sort(result->begin(), result->end()); + CHECK(std::unique(result->begin(), result->end()) == result->end()); + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..ae8cfeecad9b9cd631db3e9865bb3c3ff28a2e48 --- /dev/null +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ + +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +// An XLA cluster hoists all resource reads to the beginning of the cluster +// execution and all the resource writes to the end. This means it cannot +// enforce arbitrary ordering dependencies (via control or data edges) between +// resource operations. Since all resource reads happen before all resource +// writes, edges constraining resource reads to happen before resource writes +// are fine, but all other kinds of edges are problematic. This analysis +// returns the set of pairs of resource operations that cannot be put in the +// same cluster because XLA cannot respect the dependencies between them in the +// TensorFlow program. +// +// The restrictions are not transitive: it is fine to put A and C in the same +// cluster even if the returned set contains (A,B) and (B,C). +// +// In other words, if these pairs are seen as edges in an undirected graph of +// the nodes in `g` then auto-clustering is at least as constrained as the graph +// coloring problem on this graph. +// +// +// For instance if we auto-cluster all operations in this TensorFlow graph: +// +// ReadVariableOp0 -> ReadVariableOp1 +// | +// v +// AssignVariableOp0 -> AssignVariableOp1 +// +// we will lose the ReadVariableOp0 -> ReadVariableOp1 and the +// AssignVariableOp0 -> AssignVariableOp1 dependencies. I.e. it is possible for +// XlaLaunchOp to issue ReadVariableOp1 before ReadVariableOp0 since it reads +// all the resource variables when the cluster starts executing without any +// particular ordering between them; same holds for the AssignVariableOp0 -> +// AssignVariableOp1 edge.
The ReadVariableOp1 -> AssignVariableOp0 edge will +// be respected by XlaLaunchOp though because all reads happen before all +// writes. +// +// +// NB! The result computed by this analysis assumes that we don't auto-cluster +// back-edges (i.e. the edges from NextIteration to Merge). +// +// NB! The result computed by this analysis assumes that we don't auto-cluster +// functional control flow nodes containing resource operations. +// +// If `resource_ops_to_ignore` is set then nodes for which it returns true are +// ignored (we pretend these nodes are not resource operations). +Status ComputeIncompatibleResourceOperationPairs( + const Graph& g, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + std::vector>* result); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e54b547abcfea698fe79e81dce547ea7858ff829 --- /dev/null +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc @@ -0,0 +1,540 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +Node* MakeRead(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output read = + ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT); + return read.node(); +} + +Node* MakeWrite(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = + ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle, + value_to_write); + return assign_op.operation.node(); +} + +Node* MakeModify(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, 
TensorShape({})); + Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f); + ops::AssignAddVariableOp assign_add_op(scope.WithOpName("Increment" + id), + var_handle, value_to_write); + return assign_add_op.operation.node(); +} + +Node* MakeNeutral(const Scope& scope, const string& id) { + return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); +} + +Status ComputeIncompatiblePairs(Graph* g, + std::vector>* result) { + FixupSourceAndSinkEdges(g); + return ComputeIncompatibleResourceOperationPairs(*g, &g->flib_def(), {}, + result); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair write_read_pair = {write->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], write_read_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(read, write); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 0); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadWriteNoEdges) { + Scope root = Scope::NewRootScope().ExitOnError(); + + MakeRead(root, "R"); + MakeWrite(root, "W"); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 0); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadModify) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + + 
root.graph()->AddControlEdge(read, modify); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 1); + std::pair read_modify_pair = {read->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], read_modify_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ModifyRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + + root.graph()->AddControlEdge(modify, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair modify_read_pair = {modify->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], modify_read_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ModifyWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(modify, write); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 1); + std::pair modify_write_pair = {modify->id(), write->id()}; + EXPECT_EQ(incompatible_pairs[0], modify_write_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteModify) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, modify); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair write_modify_pair = {write->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], write_modify_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadModifyWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read 
= MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(read, modify); + root.graph()->AddControlEdge(modify, write); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 2); + std::pair modify_write_pair = {modify->id(), write->id()}; + std::pair read_modify_pair = {read->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], read_modify_pair); + EXPECT_EQ(incompatible_pairs[1], modify_write_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteModifyRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, modify); + root.graph()->AddControlEdge(modify, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 3); + + std::pair write_modify_pair = {write->id(), modify->id()}; + std::pair modify_read_pair = {modify->id(), read->id()}; + std::pair write_read_pair = {write->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], modify_read_pair); + EXPECT_EQ(incompatible_pairs[1], write_read_pair); + EXPECT_EQ(incompatible_pairs[2], write_modify_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, read); + root.graph()->AddControlEdge(read, modify); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 3); + + std::pair write_modify_pair = {write->id(), modify->id()}; + std::pair write_read_pair = 
{write->id(), read->id()}; + std::pair read_modify_pair = {read->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], read_modify_pair); + EXPECT_EQ(incompatible_pairs[1], write_read_pair); + EXPECT_EQ(incompatible_pairs[2], write_modify_pair); +} + +FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"}, + /*attr_def*/ + {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)}, + /*ret_def=*/{{"out", "out:output:0"}}); + *flib_def.add_function() = std::move(func); + return flib_def; +} + +Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name, + Status* status) { + NodeDef call_node; + call_node.set_name(node_name); + call_node.set_op(callee_name); + return graph->AddNode(call_node, status); +} + +TEST(ResourceOperationSafetyAnalysisTest, CallRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* read = MakeRead(root, "R"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(call, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair call_read_edge = {call->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], call_read_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadCall) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* read = MakeRead(root, "R"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", 
&status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(read, call); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair read_call_edge = {read->id(), call->id()}; + EXPECT_EQ(incompatible_pairs[0], read_call_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, CallWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* write = MakeWrite(root, "W"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(call, write); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair call_write_edge = {call->id(), write->id()}; + EXPECT_EQ(incompatible_pairs[0], call_write_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteCall) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* write = MakeWrite(root, "W"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(write, call); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair write_call_edge = {write->id(), call->id()}; + EXPECT_EQ(incompatible_pairs[0], write_call_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, SymbolicGradientRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + 
CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* read = MakeRead(root, "R"); + NameAttrList fn; + fn.set_name("Const_func"); + Node* symbolic_gradient = + ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)}, + /*Tout=*/{DT_FLOAT}, fn) + .output[0] + .node(); + + root.graph()->AddControlEdge(symbolic_gradient, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair symbolic_gradient_read_edge = {symbolic_gradient->id(), + read->id()}; + EXPECT_EQ(incompatible_pairs[0], symbolic_gradient_read_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteSymbolicGradient) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* write = MakeWrite(root, "W"); + NameAttrList fn; + fn.set_name("Const_func"); + Node* symbolic_gradient = + ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)}, + /*Tout=*/{DT_FLOAT}, fn) + .output[0] + .node(); + + root.graph()->AddControlEdge(write, symbolic_gradient); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair write_symbolic_gradient_edge = {write->id(), + symbolic_gradient->id()}; + EXPECT_EQ(incompatible_pairs[0], write_symbolic_gradient_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, ChainOfOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* neutral_0 = MakeNeutral(root, "N0"); + Node* read_0 = MakeRead(root, "R0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral_1 = MakeNeutral(root, "N1"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, 
neutral_0); + root.graph()->AddControlEdge(neutral_0, read_0); + root.graph()->AddControlEdge(read_0, write_1); + root.graph()->AddControlEdge(write_1, neutral_1); + root.graph()->AddControlEdge(neutral_1, read_1); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 5); + std::pair write_0_read_0_pair = {write_0->id(), read_0->id()}; + std::pair write_0_read_1_pair = {write_0->id(), read_1->id()}; + std::pair write_1_read_1_pair = {write_1->id(), read_1->id()}; + std::pair write_0_write_1_pair = {write_0->id(), write_1->id()}; + std::pair read_0_read_1_pair = {read_0->id(), read_1->id()}; + + EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); + EXPECT_EQ(incompatible_pairs[1], write_0_write_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[3], read_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[4], write_1_read_1_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, DagOfOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral = MakeNeutral(root, "N"); + Node* read_0 = MakeRead(root, "R0"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral); + root.graph()->AddControlEdge(write_1, neutral); + root.graph()->AddControlEdge(neutral, read_0); + root.graph()->AddControlEdge(neutral, read_1); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 4); + std::pair write_0_read_0_pair = {write_0->id(), read_0->id()}; + std::pair write_0_read_1_pair = {write_0->id(), read_1->id()}; + std::pair write_1_read_0_pair = {write_1->id(), read_0->id()}; + std::pair write_1_read_1_pair = {write_1->id(), read_1->id()}; + + EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); + 
EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair); + EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, DagOfOpsWithRepeatedPaths) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral = MakeNeutral(root, "N"); + Node* read_0 = MakeRead(root, "R0"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral); + root.graph()->AddControlEdge(write_1, neutral); + root.graph()->AddControlEdge(neutral, read_0); + root.graph()->AddControlEdge(neutral, read_1); + root.graph()->AddControlEdge(write_1, read_1); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 4); + std::pair write_0_read_0_pair = {write_0->id(), read_0->id()}; + std::pair write_0_read_1_pair = {write_0->id(), read_1->id()}; + std::pair write_1_read_0_pair = {write_1->id(), read_0->id()}; + std::pair write_1_read_1_pair = {write_1->id(), read_1->id()}; + + EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); + EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair); + EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, Loop) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output init_value = ops::Placeholder(root.WithOpName("init"), DT_FLOAT); + Output loop_cond = ops::Placeholder(root.WithOpName("init"), DT_BOOL); + Output enter_value = + ops::internal::Enter(root.WithOpName("enter"), init_value, "fr"); + ops::Merge iv(root.WithOpName("iv"), {enter_value, enter_value}); + ops::Switch latch(root.WithOpName("latch"), iv.output, loop_cond); + ops::internal::Exit exit(root.WithOpName("exit"), iv.output); + Output next_iteration = + 
ops::NextIteration(root.WithOpName("next_iteration"), latch.output_true); + TF_ASSERT_OK( + root.graph()->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)); + + Node* write = MakeWrite(root, "W"); + Node* read = MakeRead(root, "R"); + + root.graph()->AddControlEdge(iv.output.node(), write); + root.graph()->AddControlEdge(write, read); + root.graph()->AddControlEdge(read, next_iteration.node()); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + + std::pair write_read_pair = {write->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], write_read_pair); +} + +bool IsResourceArgDef(const OpDef::ArgDef& arg_def) { + return arg_def.type() == DT_RESOURCE; +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 0a025a1fc0b268963069a8c1a3be700040be3f8e..4f2fabd658330b8ab182e13e02ed0bca41641e46 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -185,14 +186,14 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { return Status::OK(); } -gtl::optional GetXlaClusterForNode(const Node& node) { +absl::optional GetXlaClusterForNode(const Node& node) { const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr); if (attr_value == nullptr) { - return gtl::nullopt; + return absl::nullopt; } Status s = AttrValueHasType(*attr_value, "string"); if (!s.ok()) { - return gtl::nullopt; + return absl::nullopt; } return attr_value->s(); } @@ -207,4 +208,27 @@ bool HasResourceInputOrOutput(const Node& node) { void RemoveFromXlaCluster(NodeDef* node_def) { node_def->mutable_attr()->erase(kXlaClusterAttr); } + +Status AdjustCycleDetectionGraphForResourceOps( + const Graph* graph, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + GraphCycles* cycles) { + std::vector> unsafe_deps; + TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs( + *graph, flib_def, resource_ops_to_ignore, &unsafe_deps)); + + // An edge {P,Q} in `unsafe_deps` denotes that P and Q, both of which are + // operations that interact with resource variables, must not be put in the + // same cluster. We enforce this constraint by creating a phantom node, X, + // and adding edges P->X and X->Q. MarkForCompilation then cannot cluster P + // and Q together since that would create a cycle with X. 
+ + for (std::pair unsafe_dep : unsafe_deps) { + int phantom_node_id = cycles->NewNode(); + CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id)); + CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second)); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index bff76da6f9bcb06170e5aeb111da8545a6d291f8..b0439a63ca6476b6b1d63e65308712270381dd9f 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -18,9 +18,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ #define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace tensorflow { @@ -47,7 +47,7 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, // otherwise returns nullopt. -gtl::optional GetXlaClusterForNode(const Node& node); +absl::optional GetXlaClusterForNode(const Node& node); // Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute). void RemoveFromXlaCluster(NodeDef* node_def); @@ -55,6 +55,13 @@ void RemoveFromXlaCluster(NodeDef* node_def); // Returns true if `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node); +// Adds edges to `cycles` to prevent clustering resource operations that cannot +// be legally clustered. 
+Status AdjustCycleDetectionGraphForResourceOps( + const Graph* graph, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + GraphCycles* cycles); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index 2cb351e1ecdb4523a8652886af156540e4736b18..65bbf3efe85ba30f44531ff6d54b041786dca0a5 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 7140d47a9421ec73d0144e855b490f89569e6ae9..ef6b0e67d3c4007f86dc7eef89cacb4cea98fc15 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -230,7 +230,7 @@ Status XlaCompilationCache::Compile( const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options) { + const XlaCompiler::CompileOptions& compile_options) { return CompileImpl(options, function, constant_args, variable_args, ctx, compilation_result, executable, compile_options, false); } @@ -241,7 +241,7 @@ Status XlaCompilationCache::CompileSingleOp( const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options) { + const XlaCompiler::CompileOptions& compile_options) { const 
NodeDef& def = ctx->op_kernel().def(); NameAttrList name; name.set_name(def.op()); @@ -256,7 +256,7 @@ Status XlaCompilationCache::CompileImpl( const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options, + const XlaCompiler::CompileOptions& compile_options, bool compile_single_op) { CHECK_NE(executable, nullptr); VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); @@ -324,13 +324,12 @@ Status XlaCompilationCache::CompileImpl( entry->compiled = true; if (compile_single_op) { - entry->compilation_status = compiler.CompileSingleOp( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - signature.name, ctx, args, &entry->compilation_result); + entry->compilation_status = + compiler.CompileSingleOp(compile_options, signature.name, ctx, args, + &entry->compilation_result); } else { entry->compilation_status = compiler.CompileFunction( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - function, args, &entry->compilation_result); + compile_options, function, args, &entry->compilation_result); } TF_RETURN_IF_ERROR(entry->compilation_status); CHECK_EQ(entry->executable.get(), nullptr); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index fc5f008f4f52c32d97e680784082d0e7bcb7d8eb..10ad87e38cc4d614e869782329f84351bc3b1f0b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -70,7 +70,7 @@ class XlaCompilationCache : public ResourceBase { OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options); + const XlaCompiler::CompileOptions& compile_options); // As above, but calls XlaCompiler::CompileSingleOp instead of // XlaCompiler::CompileFunction. 
@@ -80,7 +80,7 @@ class XlaCompilationCache : public ResourceBase { const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options); + const XlaCompiler::CompileOptions& compile_options); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } @@ -96,7 +96,7 @@ class XlaCompilationCache : public ResourceBase { OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options, + const XlaCompiler::CompileOptions& compile_options, bool compile_single_op); // Takes `result` which has been compiled from a Tensorflow subgraph to a diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index dd84fb34c171f8d2174444ddd3b3b476e7142718..3ba48e8c318f84a4691fb74434bc009fdd0d81bf 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -177,7 +177,7 @@ Status XlaCompileOnDemandOp::Compile( std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, &compile_options); + result, executable, compile_options); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 2a2691a6a404520da4df451293ec0cb6028a165d..f31879a2bc517d8b05e129cf0777196d0ee4dc79 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device_context.h" @@ -101,7 +102,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( } std::unique_ptr alloc = - xla::MakeUnique(); + absl::make_unique(); XlaDeviceAllocator* alloc_ptr = alloc.get(); state.allocators_[{backend, device_ordinal}] = std::move(alloc); return alloc_ptr; @@ -184,14 +185,13 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return device_type_; } -/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, - const Metadata** metadata) { +/*static*/ Status XlaDevice::GetMetadataFromDevice( + DeviceBase* device, const XlaDevice::Metadata** metadata) { *metadata = nullptr; - XlaDevice* xla_device = - dynamic_cast(ctx->device()->UnderlyingDevice()); + XlaDevice* xla_device = dynamic_cast(device->UnderlyingDevice()); if (xla_device == nullptr) { return errors::Internal( - "Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(), + "Cannot get XLA metadata from non-XLA device \"", device->name(), "\". GetMetadata must only be called on an XLA device. 
Either an " "internal bug has been triggered, or an XLA-specific op has been " "placed on the wrong device."); @@ -200,6 +200,16 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return Status::OK(); } +/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, + const Metadata** metadata) { + return GetMetadataFromDevice(ctx->device(), metadata); +} + +/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx, + const Metadata** metadata) { + return GetMetadataFromDevice(ctx->device(), metadata); +} + XlaDevice::XlaDevice( const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, @@ -327,7 +337,7 @@ xla::StatusOr XlaDevice::GetDeviceContextLocked() { // to those methods; see the bug for details. Our only saving grace at the // moment is that this race doesn't seem to occur in practice. if (use_gpu_device_info_) { - auto gpu_device_info = MakeUnique(); + auto gpu_device_info = absl::make_unique(); gpu_device_info->stream = stream_.get(); gpu_device_info->default_context = device_context_; set_tensorflow_gpu_device_info(gpu_device_info.get()); @@ -364,11 +374,7 @@ Status XlaDevice::FillContextMap(const Graph* graph, void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); - // When Xprof profiling is off (which is the default), constructing the - // activity is simple enough that its overhead is negligible. 
- tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), - op_kernel->IsExpensive()); - op_kernel->Compute(context); + TracingDevice::Compute(op_kernel, context); } void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index dbf35f349f84268ebac0f73a86c9ca0704e90835..92891ffa8c6e4a19623172574b17d90fd344c570 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -88,6 +88,10 @@ class XlaDevice : public LocalDevice { // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`. static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata); + // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`. + static Status GetMetadata(OpKernelConstruction* ctx, + const Metadata** metadata); + // Factory function. 'platform_name' is the name of the XLA platform. // 'device_name' is the name of the Tensorflow device to create. // 'jit_device_name' is the name of the corresponding JIT device. @@ -158,6 +162,9 @@ class XlaDevice : public LocalDevice { xla::StatusOr GetDeviceContextLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_); + static Status GetMetadataFromDevice(DeviceBase* device, + const XlaDevice::Metadata** metadata); + mutex mu_; // The metadata of this XlaDevice. 
const Metadata xla_metadata_; diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 0a0c0892411e8ebcd5624a29f3bd020fe6483944..ee07c5c9643ef1119b9077326c1cf7c83930e90c 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -91,7 +91,8 @@ Status XlaTransferManager::TransferLiteralToDevice( const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " " << shaped_buffer.ToString(); - if (UseMultipleStreams()) { + if (UseMultipleStreams() && !transfer_manager_->CanShapedBufferBeAccessedNow( + stream_->parent(), shaped_buffer)) { // Initially wait for the compute stream so that memory allocations are // synchronized. host_to_device_stream_->ThenWaitFor(stream_.get()); @@ -123,11 +124,11 @@ void XlaTransferManager::TransferLiteralFromDevice( TensorReference ref(device_tensor); transfer_manager_->TransferLiteralFromDevice( device_to_host_stream_.get(), shaped_buffer, literal, - [=, &shaped_buffer, &literal](xla::Status status) { + [=, &shaped_buffer](xla::Status status) { ref.Unref(); done([&]() -> Status { - VLOG(1) << "Transfer from device as literal: " << literal.ToString() - << " " << shaped_buffer.ToString(); + VLOG(1) << "Transfer from device as literal: " + << shaped_buffer.ToString(); return status; }()); }); @@ -183,18 +184,6 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, return; } status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); - if (status.ok()) { - xla_tensor->set_host_tensor(*cpu_tensor); - host_to_device_stream_->ThenDoHostCallback([this, done]() { - // We must not call the done closure directly from DoHostCallback - // to avoid a deadlock. If done() is the callback that ends an - // Executor's run, the Executor may call XlaDevice::Sync() inside the - // callback. 
This deadlocks, because XlaDevice::Sync() waits for all - // stream activity to complete. - thread_pool_->Schedule([done]() { done(Status::OK()); }); - }); - return; - } } else { se::DeviceMemoryBase dev_dst_ptr = XlaTensor::DeviceMemoryFromTensor(*device_tensor); @@ -207,8 +196,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, host_to_device_stream_.get(), block_status.error_message().c_str()); } } - xla_tensor->set_host_tensor(*cpu_tensor); - + if (status.ok()) { + xla_tensor->set_host_tensor(*cpu_tensor); + } done(status); } diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index da3e329247e825d4a33a53dc310899d6ba6ce9cf..13da5d2f948df671df6d0d80687321eaaa923943 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -215,6 +215,8 @@ class XlaAssignVariableOp : public AsyncOpKernel { AnonymousIteratorHandleOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ IteratorGetNextOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \ + IteratorGetNextSyncOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ .Device(DEVICE) \ .HostMemory("string_handle"), \ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 4b499b161371ecece14447b29fbf809b6e8857db..07cfab615157650aea0e15cdafa8c9b0925f9e5f 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -41,8 +41,8 @@ static bool IsShapeConsumerOp(const Node& node) { } // Returns true if the op can be decomposed into XLA ops for which -// there are fusable elemental implementations. -bool IsXlaFusable(const NodeDef& node) { +// there are fusible elemental implementations. 
+static bool IsXlaFusible(const NodeDef& node) { static const std::unordered_set* elementwise_ops = new std::unordered_set( {// tf2xla/kernels/aggregate_ops.cc @@ -176,9 +176,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type)); if (device_type.type_string().find("XLA") != string::npos) continue; - // Assume all fusable ops are registered. + // Assume all fusible ops are registered. // TODO(hpucha): Check for registration if possible. - if (!IsXlaFusable(node->def())) { + if (!IsXlaFusible(node->def())) { continue; } @@ -208,6 +208,8 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, GraphCycles cycles; TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( + &graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles)); // TODO(hpucha): Make clustering more robust. There are two known issues that // we need to mitigate: (a) Non-resource variables can cause deadlocks diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc index 5736760a878dc857a8558093054d0adc0f727398..68e19c8a135735a79fcabf121e619157fa22b4d8 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc @@ -14,6 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/xla_fusion_optimizer.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/graph/graph_def_builder.h" @@ -71,7 +73,7 @@ TEST_F(XlaFusionOptimizerTest, Chains) { EXPECT_TRUE(clusters.find("D") == clusters.cend()); } -TEST_F(XlaFusionOptimizerTest, FusableOps) { +TEST_F(XlaFusionOptimizerTest, FusibleOps) { GraphDef graph; { GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); @@ -179,5 +181,28 @@ TEST_F(XlaFusionOptimizerTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } +TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output var_handle = + ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({})); + Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f); + Output begin = ops::Const(root.WithOpName("begin"), 0); + Output end = ops::Const(root.WithOpName("end"), 1); + Output strides = ops::Const(root.WithOpName("strides"), 1); + ops::ResourceStridedSliceAssign assign_1( + root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign); + ops::ResourceStridedSliceAssign assign_2( + root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign); + root.graph()->AddControlEdge(assign_1.operation.node(), + assign_2.operation.node()); + grappler::GrapplerItem item; + root.graph()->ToGraphDef(&item.graph); + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_NE(clusters["assign_1"], clusters["assign_2"]); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 
4efbb2d5d7cf09d9cf1e35c8cf5403e7e0dfe733..affeab4a8c43b63ac0e2b8ef40de5223ce39d410 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -175,7 +176,7 @@ void XlaComputationLaunchContext::PopulateInputs( << " not the same as on-host shape " << xla::ShapeUtil::HumanStringWithLayout(shape); se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); - arg_buffers_[i] = xla::MakeUnique( + arg_buffers_[i] = absl::make_unique( /*on_host_shape=*/shape, /*on_device_shape=*/shape, client_->platform(), client_->default_device_ordinal()); arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); @@ -270,31 +271,36 @@ Status XlaComputationLaunchContext::PopulateOutputs( } } else { const TensorShape& shape = kernel->outputs[i].shape; - VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); - - se::DeviceMemoryBase buffer = output.buffer({output_num}); - if (allocate_xla_tensors_) { - Tensor* output_tensor; - TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); - XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); - if (xla_tensor) { - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); - if (use_multiple_streams_) { - xla_tensor->SetDefinedOn(stream, definition_event); + const DataType& type = kernel->outputs[i].type; + VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " + << DataTypeString(type); + if (type == DT_RESOURCE) { + ctx->set_output(i, ctx->input(kernel->outputs[i].input_index)); + } else { + se::DeviceMemoryBase buffer = output.buffer({output_num}); + if (allocate_xla_tensors_) { + Tensor* output_tensor; + TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, 
&output_tensor)); + XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); + if (xla_tensor) { + xla_tensor->set_shaped_buffer(ScopedShapedBuffer( + ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + if (use_multiple_streams_) { + xla_tensor->SetDefinedOn(stream, definition_event); + } + } else { + // xla_tensor wasn't valid, which must mean this is a zero-element + // tensor. + CHECK_EQ(output_tensor->TotalBytes(), 0); } } else { - // xla_tensor wasn't valid, which must mean this is a zero-element - // tensor. - CHECK_EQ(output_tensor->TotalBytes(), 0); + Tensor output_tensor = XlaTensorBuffer::MakeTensor( + ctx->expected_output_dtype(i), shape, buffer, allocator); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); + ctx->set_output(i, output_tensor); } - } else { - Tensor output_tensor = XlaTensorBuffer::MakeTensor( - ctx->expected_output_dtype(i), shape, buffer, allocator); - output.set_buffer(xla::OwningDeviceMemory(), {output_num}); - ctx->set_output(i, output_tensor); + ++output_num; } - ++output_num; } if (VLOG_IS_ON(3)) { diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 4232f514b3b48681bf510ee568f916f5f4ebe882..7ac275fab833400b90ced0180192845c9be30534 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -167,4 +167,4 @@ xla::ScopedShapedBuffer ExtractSubShapedBuffer( } // namespace tensorflow -#endif +#endif // TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 8d36d0fa0a8230bcd1b16cc67de104e09358144f..4c9bb2e27b0ca3c83848be7fdf189fdbad89cee5 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -18,6 +18,7 @@ limitations under the License. 
#include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/core/framework/allocator.h" @@ -70,7 +71,7 @@ class XlaTensor { // Mutates the XlaTensor to set the ShapedBuffer. void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { shaped_buffer_ = - xla::MakeUnique(std::move(shaped_buffer)); + absl::make_unique(std::move(shaped_buffer)); } // Some tensors on the device may have known values on the host. We use these @@ -127,4 +128,4 @@ class XlaTensor { } // namespace tensorflow -#endif +#endif // TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index ae98b3f0f9d5dac66b9716ad84a9f0371511e9b6..34defe1c7ade687a7524390cee78657e1a27f5b4 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -72,7 +72,7 @@ py_test( tf_xla_py_test( name = "adadelta_test", - size = "medium", + size = "large", srcs = ["adadelta_test.py"], deps = [ ":xla_test", @@ -251,6 +251,7 @@ tf_xla_py_test( tf_xla_py_test( name = "matrix_triangular_solve_op_test", size = "small", + timeout = "moderate", srcs = ["matrix_triangular_solve_op_test.py"], tags = ["optonly"], deps = [ @@ -387,6 +388,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "reshape_op_test", + size = "small", + srcs = ["reshape_op_test.py"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "@absl_py//absl/testing:parameterized", + ], +) + tf_xla_py_test( name = "dynamic_stitch_test", size = "small", @@ -559,6 +573,7 @@ tf_xla_py_test( tf_xla_py_test( name = "matrix_band_part_test", size = "medium", + timeout = "long", srcs = ["matrix_band_part_test.py"], tags = ["optonly"], deps = [ @@ -715,6 +730,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:math_ops", 
"//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", ], ) @@ -1177,3 +1193,19 @@ tf_xla_py_test( "//tensorflow/python:platform_test", ], ) + +tf_xla_py_test( + name = "xla_ops_test", + size = "small", + srcs = ["xla_ops_test.py"], + disabled_backends = ["cpu_ondemand"], + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py index 3e3c09c66e72c4de141b64cea3c4693fabb7b2a2..b7b7fda293b69d6f0cec61d0d234277636a3670d 100644 --- a/tensorflow/compiler/tests/adadelta_test.py +++ b/tensorflow/compiler/tests/adadelta_test.py @@ -33,7 +33,7 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase): def testBasic(self): num_updates = 4 # number of ADADELTA steps to perform for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): for grad in [0.2, 0.1, 0.01]: for lr in [1.0, 0.5, 0.1]: var0_init = [1.0, 2.0] diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py index dc1625793aa44b96d3b96e175237caf96e7d7e74..69fb3ec2964a09508e612515b9e291fc14121d68 100644 --- a/tensorflow/compiler/tests/adagrad_da_test.py +++ b/tensorflow/compiler/tests/adagrad_da_test.py @@ -33,7 +33,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAWithoutRegularizationBasic1(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) @@ -69,7 +69,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def 
testAdagradDAwithoutRegularizationBasic2(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) @@ -100,7 +100,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAWithL1(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) @@ -131,7 +131,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAWithL1_L2(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index d775850a80e9f83f7b2c9f1cf8997dd50e229635..ab69319c59fb07e7ce56c3c287a50a6290effdfd 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -32,7 +32,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -57,7 +57,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), 
self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -83,7 +83,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase): def testSharing(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 0d2e4d029636577adc74784d9a8b3494b94dc67d..df0f21471a1c67e69e037f6409bcab1297d3399d 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -53,7 +54,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: # TODO: test fails for float16 due to excessive precision requirements. - if dtype == np.float16: + if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) @@ -95,7 +96,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: # TODO: test fails for float16 due to excessive precision requirements. 
- if dtype == np.float16: + if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) @@ -137,7 +138,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testSharing(self): for dtype in self.float_types: # TODO: test fails for float16 due to excessive precision requirements. - if dtype == np.float16: + if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py index c4fdbc5974319db9243eb2c323746cbaaea795f6..3ed1d41b7121f44dd7470f61180f7a7055369174 100644 --- a/tensorflow/compiler/tests/adamax_test.py +++ b/tensorflow/compiler/tests/adamax_test.py @@ -49,7 +49,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): def testBasic(self): for i, dtype in enumerate(self.float_types): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 @@ -100,7 +100,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. 
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py index 9ec5a964cbb4dd98d2ef2d0b684872292118800f..1bc07ace23ccdc83103abe71ee11b72994c75a6d 100644 --- a/tensorflow/compiler/tests/addsign_test.py +++ b/tensorflow/compiler/tests/addsign_test.py @@ -63,7 +63,7 @@ class AddSignTest(xla_test.XLATestCase): alpha=1.0, beta=0.9): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): # Initialize variables for numpy implementation. m0, m1 = 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index 9d3a889b1f54c813e881bb03b5275f809af1b3c8..4155342787fbbdeaf5c5958c44d007b1ea0660ed 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -40,7 +40,7 @@ class ArgMinMaxTest(xla_test.XLATestCase): op_input: numpy input array to use as input to 'op'. expected: numpy array representing the expected output of 'op'. 
""" - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): pinp = array_ops.placeholder( dtypes.as_dtype(op_input.dtype), op_input.shape, name="a") diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 0aafda7fb4d710f154157ee352d6616e5aa8935f..17280e445b329d1541aaed78ec106f8f282cbc74 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -36,7 +36,7 @@ class BinaryOpsTest(xla_test.XLATestCase): """Test cases for binary operators.""" def _testBinary(self, op, a, b, expected, equality_test=None): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") @@ -1010,7 +1010,38 @@ class BinaryOpsTest(xla_test.XLATestCase): [7, 7, 7, 7, 7, 7]], dtype=dtype)) - def testMirrorPad(self): + def testSymmetricMirrorPad(self): + mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC") + for dtype in self.numeric_types: + self._testBinary( + mirror_pad, + np.array( + [ + [1, 2, 3], # + [4, 5, 6], # + ], + dtype=dtype), + np.array([[ + 2, + 2, + ], [3, 3]], dtype=np.int32), + expected=np.array( + [ + [6, 5, 4, 4, 5, 6, 6, 5, 4], # + [3, 2, 1, 1, 2, 3, 3, 2, 1], # + [3, 2, 1, 1, 2, 3, 3, 2, 1], # + [6, 5, 4, 4, 5, 6, 6, 5, 4], # + [6, 5, 4, 4, 5, 6, 6, 5, 4], # + [3, 2, 1, 1, 2, 3, 3, 2, 1], # + ], + dtype=dtype)) + self._testBinary( + mirror_pad, + np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype), + np.array([[0, 0], [0, 0]], dtype=np.int32), + expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)) + + def testReflectMirrorPad(self): mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT") for dtype in self.numeric_types: self._testBinary( @@ -1165,6 +1196,16 @@ class BinaryOpsTest(xla_test.XLATestCase): def 
testTile(self): for dtype in self.numeric_types: + self._testBinary( + array_ops.tile, + np.array([[6], [3], [4]], dtype=dtype), + np.array([2, 0], dtype=np.int32), + expected=np.empty([6, 0], dtype=dtype)) + self._testBinary( + array_ops.tile, + np.array([[6, 3, 4]], dtype=dtype), + np.array([2, 0], dtype=np.int32), + expected=np.empty([2, 0], dtype=dtype)) self._testBinary( array_ops.tile, np.array([[6]], dtype=dtype), @@ -1362,5 +1403,40 @@ class BinaryOpsTest(xla_test.XLATestCase): [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]], dtype=dtype)) + def testBroadcastTo(self): + for dtype in self.all_types: + x = np.random.randint(0, high=100, size=[2, 3]) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([2, 3], dtype=np.int32), + expected=x) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([6, 6], dtype=np.int32), + expected=np.tile(x, [3, 2])) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([7, 4, 3], dtype=np.int32), + expected=np.tile(x, [7, 2, 1])) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([7, 0, 3], dtype=np.int32), + expected=np.zeros([7, 0, 3], dtype=dtype)) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([7, 1, 2, 9], dtype=np.int32), + expected=np.tile(x, [7, 1, 1, 3])) + self._testBinary( + array_ops.broadcast_to, + np.zeros([2, 0], dtype=dtype), + np.array([4, 0], dtype=np.int32), + expected=np.zeros([4, 0], dtype=dtype)) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py index ef4d5f6322b7ae79b051795b5af7e6f7f1e55550..5c24db539bce5df701d8229290ddb4c20997d40a 100644 --- a/tensorflow/compiler/tests/bucketize_op_test.py +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class BucketizationOpTest(xla_test.XLATestCase): def testInt(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = 
array_ops.placeholder(dtypes.int32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) @@ -38,7 +38,7 @@ class BucketizationOpTest(xla_test.XLATestCase): sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]})) def testFloat(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.float32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.]) @@ -48,7 +48,7 @@ class BucketizationOpTest(xla_test.XLATestCase): sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]})) def test2DInput(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.float32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) @@ -58,7 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase): {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]})) def testInvalidBoundariesOrder(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.int32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11]) @@ -67,7 +67,7 @@ class BucketizationOpTest(xla_test.XLATestCase): sess.run(op, {p: [-5, 0]}) def testBoundariesNotList(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, "Expected list.*"): p = array_ops.placeholder(dtypes.int32) with self.test_scope(): diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index a4e7f75081dfd07fd4b5c94c33908aab8e7d8aa9..a57d1dc81ea2c9c188b0a3005904738aa8156bf3 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -56,7 +56,7 @@ class CategoricalTest(xla_test.XLATestCase): Returns: Frequencies from sampled classes; shape [batch_size, num_classes]. 
""" - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): random_seed.set_random_seed(1618) op = random_ops.multinomial(logits, num_samples, output_dtype=dtypes.int32) @@ -79,7 +79,7 @@ class CategoricalTest(xla_test.XLATestCase): def _testRngIsNotConstant(self, rng, dtype, output_dtype): # Tests that 'rng' does not always return the same value. - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = rng(dtype, output_dtype) @@ -107,7 +107,7 @@ class CategoricalTest(xla_test.XLATestCase): def testCategoricalIsInRange(self): for dtype in self.float_types: for output_dtype in self.output_dtypes(): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = random_ops.multinomial( array_ops.ones(shape=[1, 20], dtype=dtype), 1000, diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index ed532db0ee5553a275192e6cc3ebf394075fa0e1..d1896a50f7037f2972cba8a4fa16cc1e2cd4fe3e 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -54,7 +54,7 @@ class CholeskyOpTest(xla_test.XLATestCase): def _verifyCholesky(self, x, atol=1e-6): # Verify that LL^T == x. 
- with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder( dtypes.as_dtype(x.dtype), shape=x.shape) with self.test_scope(): diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index e42ebf8f9e01dab13cde15979ffc42b7c0fbc57b..88bd58b2da6b2892f898ad10f3467d8ce39d6388 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -38,7 +38,7 @@ class ClusteringTest(xla_test.XLATestCase): val1 = np.array([4, 3, 2, 1], dtype=np.float32) val2 = np.array([5, 6, 7, 8], dtype=np.float32) expected = val1 + val2 - with self.test_session(): + with self.cached_session(): with self.test_scope(): input1 = constant_op.constant(val1, name="const1") input2 = constant_op.constant(val2, name="const2") @@ -50,7 +50,7 @@ class ClusteringTest(xla_test.XLATestCase): val1 = np.array([4, 3, 2, 1]).astype(np.float32) val2 = np.array([5, 6, 7, 8]).astype(np.float32) expected = val1 + val2 - with self.test_session(): + with self.cached_session(): with ops.device(CPU_DEVICE): input1 = constant_op.constant(val1, name="const1") input2 = constant_op.constant(val2, name="const2") @@ -68,7 +68,7 @@ class ClusteringTest(xla_test.XLATestCase): # where x and z are placed on the CPU and y and w are placed on the XLA # device. If y and w are clustered for compilation, then the graph will # deadlock since the clustered graph will contain a self-loop. 
- with self.test_session() as sess: + with self.cached_session() as sess: with ops.device(CPU_DEVICE): x = array_ops.placeholder(dtypes.float32, [2]) with self.test_scope(): @@ -81,7 +81,7 @@ class ClusteringTest(xla_test.XLATestCase): self.assertAllClose(result, [12., 2.], rtol=1e-3) def testHostMemory(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtypes.int32) with self.test_scope(): y = x + 1 diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index d9ad4281477e87f79f2ecb52989ae86a5030d0cc..37e5318bb54c5d8ecdedc7bb346e89765f2adf35 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest class ConcatTest(xla_test.XLATestCase): def testHStack(self): - with self.test_session(): + with self.cached_session(): p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) with self.test_scope(): @@ -49,7 +49,7 @@ class ConcatTest(xla_test.XLATestCase): self.assertAllEqual(result[4:, :], params[p2]) def testVStack(self): - with self.test_session(): + with self.cached_session(): p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) with self.test_scope(): @@ -65,7 +65,7 @@ class ConcatTest(xla_test.XLATestCase): self.assertAllEqual(result[:, 4:], params[p2]) def testInt32(self): - with self.test_session(): + with self.cached_session(): p1 = np.random.rand(2, 3).astype("i") p2 = np.random.rand(2, 3).astype("i") x1 = constant_op.constant(p1) @@ -88,7 +88,7 @@ class ConcatTest(xla_test.XLATestCase): dtype_feed = dtypes.float32 else: dtype_feed = dtype - with self.test_session(): + with self.cached_session(): p = [] for i in np.arange(num_tensors): input_shape = shape @@ -130,7 +130,7 @@ class ConcatTest(xla_test.XLATestCase): 
self._testRandom(dtypes.int32) def _testGradientsSimple(self): - with self.test_session(): + with self.cached_session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -157,7 +157,7 @@ class ConcatTest(xla_test.XLATestCase): self._testGradientsSimple() def _testGradientsFirstDim(self): - with self.test_session(): + with self.cached_session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -185,7 +185,7 @@ class ConcatTest(xla_test.XLATestCase): self._testGradientsFirstDim() def _testGradientsLastDim(self): - with self.test_session(): + with self.cached_session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -220,7 +220,7 @@ class ConcatTest(xla_test.XLATestCase): # Random dim to concat on concat_dim = np.random.randint(5) concat_dim_sizes = np.random.randint(1, 5, size=num_tensors) - with self.test_session(): + with self.cached_session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase): def DISABLED_testZeroSize(self): # Verify that concat doesn't crash and burn for zero size inputs np.random.seed(7) - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): for shape0 in (), (2,): axis = len(shape0) @@ -276,14 +276,14 @@ class ConcatTest(xla_test.XLATestCase): def testConcatTuple(self): c1 = np.random.rand(4, 4).astype(np.float32) c2 = np.random.rand(4, 4).astype(np.float32) - with self.test_session(): + with self.cached_session(): with self.test_scope(): concat_list_t = array_ops.concat([c1, c2], 0) concat_tuple_t = array_ops.concat((c1, c2), 0) self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval()) def testConcatNoScalars(self): - with self.test_session(): + with self.cached_session(): with self.test_scope(): scalar = constant_op.constant(7) dim = array_ops.placeholder(dtypes.int32) @@ -295,7 +295,7 @@ class ConcatTest(xla_test.XLATestCase): class ConcatOffsetTest(xla_test.XLATestCase): def testBasic(self): - with self.test_session() 
as sess: + with self.cached_session() as sess: with self.test_scope(): cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) @@ -309,7 +309,7 @@ class ConcatOffsetTest(xla_test.XLATestCase): class PackTest(xla_test.XLATestCase): def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) @@ -319,7 +319,7 @@ class PackTest(xla_test.XLATestCase): self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]]) def testScalars(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): s0 = constant_op.constant(2, dtypes.int32) s1 = constant_op.constant(3, dtypes.int32) @@ -329,7 +329,7 @@ class PackTest(xla_test.XLATestCase): self.assertAllEqual(ans, [2, 3, 5]) def testEmpty(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): s0 = constant_op.constant([[]], dtypes.int32) s1 = constant_op.constant([[]], dtypes.int32) diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index f9db103f6d0f9ea0e393a0971593552ec5c14079..af00ff287d43a8542b5a3d14eedc00c3d7aef1b7 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -87,7 +87,7 @@ class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase): dilations = test_utils.PermuteDimsBetweenDataFormats( dilations, data_format_src, data_format_dst) - with self.test_session() as sess: + with self.cached_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) with self.test_scope(): @@ -288,7 +288,7 @@ class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase): dilations = test_utils.PermuteDimsBetweenDataFormats( dilations, data_format_src, 
data_format_dst) - with self.test_session() as sess: + with self.cached_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) with self.test_scope(): @@ -586,7 +586,7 @@ class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase): dilations = test_utils.PermuteDimsBetweenDataFormats( dilations, data_format_src, data_format_dst) - with self.test_session() as sess: + with self.cached_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) with self.test_scope(): diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index 31ee41f04f27d387415e9fa2c4fa70b33cab7b04..33fd983b5485e503c2fcc96db2dfdecfc41e309f 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -36,7 +36,7 @@ from tensorflow.python.platform import googletest class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): def testGradient(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): for padding in ["SAME", "VALID"]: for stride in [1, 2]: np.random.seed(1) @@ -69,7 +69,7 @@ class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): class Conv3DTransposeTest(xla_test.XLATestCase): def testConv3DTransposeSingleStride(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): strides = [1, 1, 1, 1, 1] # Input, output: [batch, depth, height, width, channel] @@ -119,7 +119,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): self.assertAllClose(target, value[n, d, h, w, k]) def testConv3DTransposeSame(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): strides = [1, 2, 2, 2, 1] # Input, output: [batch, depth, height, width, depth] @@ -157,7 +157,7 @@ class 
Conv3DTransposeTest(xla_test.XLATestCase): self.assertAllClose(target, value[n, d, h, w, k]) def testConv3DTransposeValid(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): strides = [1, 2, 2, 2, 1] # Input, output: [batch, depth, height, width, depth] @@ -217,7 +217,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): np.random.seed(1) # Make it reproducible. x_val = np.random.random_sample(x_shape).astype(np.float64) f_val = np.random.random_sample(f_shape).astype(np.float64) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): x = constant_op.constant(x_val, name="x", dtype=dtypes.float32) f = constant_op.constant(f_val, name="f", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 865f60ccab46ec6829e49409508303052944e13b..04f3b3ef4905984b0432a536c3b1c275738ede17 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -86,7 +86,7 @@ class DenseLayerTest(test.TestCase): XlaLaunch op by XLA. """ - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32) with jit_scope(): y = layers.dense(x, 3) @@ -113,7 +113,7 @@ class DenseLayerTest(test.TestCase): cluster, causing dense layer to be split into TWO XlaLaunch ops. 
""" - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) with jit_scope(): y = layers.dense(x, 3) diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 98dc73e189f99b7b811487756659d89dacb97d8a..6ef8a68ca5d35d3d2f78f0cb491e7bb98ff97ac9 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -151,7 +151,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): dtype=data_type).reshape(tensor_in_sizes) x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)], dtype=data_type).reshape(filter_in_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: if data_type == np.float32: tolerance = 1e-4 else: @@ -247,7 +247,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): dtype=np.float32).reshape(tensor_in_sizes) x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)], dtype=np.float32).reshape(filter_in_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32) t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=np.float32) with self.test_scope(): @@ -321,7 +321,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): x2 = np.random.rand(*output_sizes).astype(np.float32) def _GetVal(use_xla): - with self.test_session(): + with self.cached_session(): t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)]) t1 = array_ops.placeholder(np.float32, shape=filter_sizes) t2 = array_ops.placeholder(np.float32, shape=output_sizes) @@ -356,7 +356,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): x2 = np.random.rand(*output_sizes).astype(np.float32) def _GetVal(use_xla): - with self.test_session(): + with self.cached_session(): t0 = array_ops.placeholder(np.float32, shape=input_sizes) t1 = constant_op.constant(filter_sizes, 
shape=[len(filter_sizes)]) t2 = array_ops.placeholder(np.float32, shape=output_sizes) diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py index 154e36b10e6da409606ae6022aaf53e34c8e37cc..5f01e128f0b0fa725d99b00ba3406bd50a1b8962 100644 --- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import test class DynamicUpdateSliceOpsTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py index edd78153b56bb5bf1c268936fb82a60581389733..50b04daa6b9f4159a3c4bdeecaf900a5b35a833c 100644 --- a/tensorflow/compiler/tests/dynamic_stitch_test.py +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import googletest class DynamicStitchTest(xla_test.XLATestCase): def _AssertDynamicStitchResultIs(self, indices, data, expected): - with self.test_session() as session: + with self.cached_session() as session: index_placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices ] diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index ff097f80f1f2586bd483a54d532750c90b2a8b03..63cee550fde9d9d4314b1541fba191df776a4da2 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -101,7 +101,7 @@ class EagerTest(xla_test.XLATestCase): self.assertAllEqual(15, product) # Run some ops graphly - with context.graph_mode(), self.test_session() as sess: + with context.graph_mode(), self.cached_session() as sess: with self.test_scope(): three = 
constant_op.constant(3) five = constant_op.constant(5) @@ -351,6 +351,38 @@ class EagerFunctionTest(xla_test.XLATestCase): var = f(v) self.assertEqual(2.0, var.numpy()) + def testReturnResourceHandle(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]]) + + def f(v): + return v.handle + + f = function.defun(f) + handle = f(v) + self.assertAllEqual(v.numpy(), + resource_variable_ops.read_variable_op( + handle, dtypes.float32).numpy()) + + def testReturnMultipleResourceHandles(self): + with self.test_scope(): + v1 = resource_variable_ops.ResourceVariable(1.25) + v2 = resource_variable_ops.ResourceVariable(2.0) + + def f(v): + return v.handle, 3.0 * v, v2.handle, v + v2 + + f = function.defun(f) + v1_handle, v1_times_3, v2_handle, variable_sum = f(v1) + self.assertAllEqual(v1.numpy(), + resource_variable_ops.read_variable_op( + v1_handle, dtypes.float32).numpy()) + self.assertEqual(3.75, v1_times_3.numpy()) + self.assertAllEqual(v2.numpy(), + resource_variable_ops.read_variable_op( + v2_handle, dtypes.float32).numpy()) + self.assertEqual(3.25, variable_sum.numpy()) + def testAllArgumentKinds(self): """Test a complex function that takes different argument kinds. 
@@ -443,7 +475,6 @@ class EagerFunctionTest(xla_test.XLATestCase): self.assertAllEqual((2, 3, 4), dz.shape.as_list()) def testNestedDefun(self): - self.skipTest('Nested defuns do not work on TPU at the moment') with self.test_scope(): @function.defun @@ -458,6 +489,72 @@ class EagerFunctionTest(xla_test.XLATestCase): y = two_x_plus_1(x) self.assertAllEqual([5, 7, 9], y.numpy()) + def testNestedDefunWithVariable(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun + def g(x): + x = v0 * x + return x + + @function.defun + def f(x): + x = g(v0 * x) + return x + + x = constant_op.constant(3.0) + y = f(x) + + self.assertEqual(75, y.numpy()) + + def testNestedDefunInGradientTape(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun + def g(x): + x = v0 * x + return x + + @function.defun + def f(x): + x = g(v0 * x) + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + y = f(x) + dy = tape.gradient(y, v0) + + self.assertEqual(75, y.numpy()) + self.assertEqual(30, dy.numpy()) + + def testNestedDefunInGradientTapeDifferentVars(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + v1 = resource_variable_ops.ResourceVariable(3.0) + + @function.defun + def g(x): + x = v1 * x + return x + + @function.defun + def f(x): + x = g(v0 * x) + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape(persistent=True) as tape: + y = f(x) + dy_v0 = tape.gradient(y, v0) + dy_v1 = tape.gradient(y, v1) + + self.assertEqual(45, y.numpy()) + self.assertEqual(9, dy_v0.numpy()) + self.assertEqual(15, dy_v1.numpy()) + class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. 
diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py index 5529fdbb090315e1d7f47589777d8a538c90db2b..37061e91d161db352b388a965eb72c9c32d3d752 100644 --- a/tensorflow/compiler/tests/extract_image_patches_op_test.py +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -44,7 +44,7 @@ class ExtractImagePatches(xla_test.XLATestCase): strides = [1] + strides + [1] rates = [1] + rates + [1] - with self.test_session(): + with self.cached_session(): image_placeholder = array_ops.placeholder(dtypes.float32) with self.test_scope(): out_tensor = array_ops.extract_image_patches( diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py index c48ab178bf53558084fb500b2811c6f0b77a7943..2178c4455609550226c89ceb185837768be1f622 100644 --- a/tensorflow/compiler/tests/fake_quant_ops_test.py +++ b/tensorflow/compiler/tests/fake_quant_ops_test.py @@ -107,7 +107,7 @@ class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase): ], dtype=np.float32) - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): input_placeholder = array_ops.placeholder( dtypes.float32, inputs.shape, name="inputs") @@ -198,7 +198,7 @@ class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase): [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0], dtype=np.float32) - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): gradient_placeholder = array_ops.placeholder( dtypes.float32, gradients.shape, name="gradients") @@ -306,7 +306,7 @@ class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase): ], dtype=np.float32) - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): input_placeholder = array_ops.placeholder( dtypes.float32, inputs.shape, name="inputs") @@ -406,7 +406,7 @@ class 
FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase): expected_backprops_wrt_min = 1.0 + 2.0 expected_backprops_wrt_max = 10.0 + 11.0 - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): gradient_placeholder = array_ops.placeholder( dtypes.float32, gradients.shape, name="gradients") diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index c64ea249ecb97991952a960a6d16e1bb3be35b17..b3e13fbaa6b33bdaa1be123be558059e96de282e 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -71,7 +71,7 @@ class FFTTest(xla_test.XLATestCase): data = np.reshape(data.astype(np.float32).view(np.complex64), shape) data = to_32bit(complex_to_input(data)) expected = to_32bit(input_to_expected(data)) - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): ph = array_ops.placeholder( dtypes.as_dtype(data.dtype), shape=data.shape) @@ -93,7 +93,7 @@ class FFTTest(xla_test.XLATestCase): data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2] expected = np.swapaxes(expected, -1, -2) expected *= window.sum() # scipy divides by window sum - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): ph = array_ops.placeholder( dtypes.as_dtype(data.dtype), shape=data.shape) diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py index 0f64cc87cde77fbbef6c4e570879e992bc34bafa..8c7edfd277c992c35a81dd5f261256a86352254e 100644 --- a/tensorflow/compiler/tests/fifo_queue_test.py +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -31,13 +31,13 @@ from tensorflow.python.platform import test class FIFOQueueTest(xla_test.XLATestCase): def testEnqueue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) enqueue_op = 
q.enqueue((10.0,)) enqueue_op.run() def testEnqueueWithShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2)) enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],)) enqueue_correct_op.run() @@ -46,7 +46,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual(1, q.size().eval()) def testMultipleDequeues(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) self.evaluate(q.enqueue([1])) self.evaluate(q.enqueue([2])) @@ -55,7 +55,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) def testQueuesDontShare(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) self.evaluate(q.enqueue(1)) q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) @@ -64,13 +64,13 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertAllEqual(self.evaluate(q.dequeue()), 1) def testEnqueueDictWithoutNames(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) with self.assertRaisesRegexp(ValueError, "must have names"): q.enqueue({"a": 12.0}) def testParallelEnqueue(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -95,7 +95,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertItemsEqual(elems, results) def testParallelDequeue(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, 
self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -119,7 +119,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertItemsEqual(elems, results) def testDequeue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) elems = [10.0, 20.0, 30.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -133,7 +133,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([elems[i]], vals) def testEnqueueAndBlockingDequeue(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32) elems = [10.0, 20.0, 30.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -163,7 +163,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([elem], result) def testMultiEnqueueAndDequeue(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32)) elems = [(5, 10.0), (10, 20.0), (15, 30.0)] enqueue_ops = [q.enqueue((x, y)) for x, y in elems] @@ -179,12 +179,12 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([y], y_val) def testQueueSizeEmpty(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) self.assertEqual([0], q.size().eval()) def testQueueSizeAfterEnqueueAndDequeue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) enqueue_op = q.enqueue((10.0,)) dequeued_t = q.dequeue() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 
1da97fd51217a0f28d4b3ba2ccfae3f6b094e65b..f1b87a5ffb73bed62a80abaa152d335f64d970c5 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -29,7 +29,6 @@ from tensorflow.python.training import adagrad from tensorflow.python.training import ftrl from tensorflow.python.training import gradient_descent - class FtrlOptimizerTest(xla_test.XLATestCase): def initVariableAndGradient(self, dtype): @@ -112,7 +111,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlwithoutRegularization(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -146,7 +145,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlwithoutRegularization2(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -174,7 +173,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlWithL1(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -196,13 +195,17 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-7.66718769, -10.91273689]), var0.eval(), rtol=1e-4) + np.array([-7.66718769, -10.91273689]), + var0.eval(), + rtol=1e-4, + 
bfloat16_rtol=1e-1, + bfloat16_atol=1e-1) self.assertAllCloseAccordingToType( np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4) def testFtrlWithL1_L2(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -236,7 +239,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): weights will tend to have smaller magnitudes with this parameter set. """ for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -259,9 +262,49 @@ class FtrlOptimizerTest(xla_test.XLATestCase): # Validate updated params self.assertAllCloseAccordingToType( - np.array([-0.21931979, -0.40642974]), var0.eval(), rtol=1e-4) + np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4) self.assertAllCloseAccordingToType( - np.array([-0.0282721, -0.07188385]), var1.eval(), rtol=1e-4) + np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4) + + def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self): + """Verifies that l2 shrinkage in FTRL does not change lr schedule.""" + for dtype in self.float_types: + with self.test_session(), self.test_scope(): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) + grads1 = constant_op.constant([0.1, 0.2], dtype=dtype) + + opt0 = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0, + 
l2_shrinkage_regularization_strength=0.1) + opt1 = ftrl.FtrlOptimizer( + 3.0, + initial_accumulator_value=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=2.0) + update0 = opt0.apply_gradients([(grads0, var0)]) + update1 = opt1.apply_gradients([(grads1, var1)]) + variables.global_variables_initializer().run() + + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval()) + + # Run 10 steps FTRL + for _ in range(10): + update0.run() + update1.run() + + # var0 is experiencing L2 shrinkage so it should be smaller than var1 + # in magnitude. + self.assertTrue((var0.eval()**2 < var1.eval()**2).all()) + accum0 = list(opt0._slots["accum"].values())[0].eval() + accum1 = list(opt1._slots["accum"].values())[0].eval() + # L2 shrinkage should not change how we update grad accumulator. + self.assertAllCloseAccordingToType(accum0, accum1) # When variables are initialized with Zero, FTRL-Proximal has two properties: # 1. 
Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical @@ -273,9 +316,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testEquivAdagradwithoutRegularization(self): steps = 5 for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val0, val1 = self.equivAdagradTest_FtrlPart(steps, dtype) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype) self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4, half_rtol=1e-2) @@ -284,9 +327,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testEquivGradientDescentwithoutRegularization(self): steps = 5 for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val0, val1 = self.equivGradientDescentTest_FtrlPart(steps, dtype) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val2, val3 = self.equivGradientDescentTest_GradientDescentPart( steps, dtype) diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 04fba444460e714ce96205361ac02ed492206b04..b1891b918c6584abce9da382088ed0037f5319fb 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ 
-90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = Func(aval, bval) - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -105,7 +105,7 @@ class FunctionTest(xla_test.XLATestCase): def testCompileTimeConstantsInDefun(self): """Tests that XLA handles compile-time constants in defuns.""" - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.float32, dtypes.int32, dtypes.int32) def Foo(a, c, d): @@ -140,7 +140,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) expected = aval + bval * 2 - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): a = array_ops.placeholder(dtypes.float32, name="a") b = array_ops.placeholder(dtypes.float32, name="b") diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 132e42ac7a28d0769b0de12ea0cee6eae752b245..8c018cccb83a05babb0b7f73b80b4f9de7267c98 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -83,7 +83,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): y_ref, mean_ref, var_ref = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format_src) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): # To avoid constant folding x_val_converted = test_utils.ConvertBetweenDataFormats( x_val, data_format_src, data_format) @@ -126,7 +126,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): y_ref, mean_ref, var_ref = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format_src) - with self.test_session() as sess, self.test_scope(): + with 
self.cached_session() as sess, self.test_scope(): # To avoid constant folding x_val_converted = test_utils.ConvertBetweenDataFormats( x_val, data_format_src, data_format) @@ -210,7 +210,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): grad_val_converted = test_utils.ConvertBetweenDataFormats( grad_val, data_format_src, data_format) x_val_converted = test_utils.ConvertBetweenDataFormats( @@ -260,7 +260,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): var_val = np.random.random_sample(scale_shape).astype(np.float32) data_format_src = "NHWC" - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): grad_val_converted = test_utils.ConvertBetweenDataFormats( grad_val, data_format_src, data_format) x_val_converted = test_utils.ConvertBetweenDataFormats( diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py index 23b0aed34fb460f50c241e5a920cb4f6f613b947..7161f4ab339b6f4069dd2b02ddbc6a89973e0074 100644 --- a/tensorflow/compiler/tests/gather_nd_op_test.py +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class GatherNdTest(xla_test.XLATestCase): def _runGather(self, params, indices): - with self.test_session(): + with self.cached_session(): paramsp = array_ops.placeholder(params.dtype) indicesp = array_ops.placeholder(indices.dtype) with self.test_scope(): @@ -46,7 +46,7 @@ class GatherNdTest(xla_test.XLATestCase): np.array([[4], [4], [0]], np.int32))) def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self): - with self.test_session(): + with self.cached_session(): params = np.ones((3, 3), 
dtype=np.float32) indices_empty = np.empty((0, 2), dtype=np.int32) diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index e9c8ef7c91a728b7dfc948fd9b315e6c9102f6a3..089d95daab7e502b4ba13796fadc2ba3f209759b 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -42,7 +42,7 @@ class GatherTest(xla_test.XLATestCase): return data def testScalar1D(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([0, 1, 2, 3, 7, 5]) for dtype in self.all_tf_types: for indices in 4, [4], [1, 2, 2, 4, 5]: @@ -55,7 +55,7 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual(np_val, gather_val) def testScalar2D(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) for dtype in self.all_tf_types: @@ -69,7 +69,7 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual(expected, gather_val) def testSimpleTwoD32(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) for dtype in self.all_tf_types: @@ -87,7 +87,7 @@ class GatherTest(xla_test.XLATestCase): if np.int64 not in self.int_types: return - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) # The indices must be in bounds for any axis. 
@@ -114,7 +114,7 @@ class GatherTest(xla_test.XLATestCase): for axis in 0, 1, 2, 3, -1, -2: params = self._buildParams(np.random.randn(*shape), dtype) indices = np.random.randint(shape[axis], size=indices_shape) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): tf_params = array_ops.placeholder(dtype=dtype) tf_indices = constant_op.constant(indices, dtype=dtypes.int32) gather = array_ops.gather(tf_params, tf_indices, axis=axis) @@ -123,7 +123,7 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual(gather_np, gather_value) def testIndicesWithDifferentDimensions(self): - with self.test_session(): + with self.cached_session(): for dtype in self.numeric_tf_types: params = array_ops.placeholder(dtype=dtype) indices = array_ops.placeholder(dtype=np.int32) @@ -137,7 +137,7 @@ class GatherTest(xla_test.XLATestCase): [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]})) def testGatherPrecision(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0], [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]]) indices = np.array([1, 2, 3, 1]) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index bf986ade06b11358552ee92df3169f965ce3f534..6fe5a66e0e6717ec738dded9196eef6ba1e2114d 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -54,7 +54,7 @@ class RGBToHSVTest(xla_test.XLATestCase): inp = GenerateNumpyRandomRGB(shape).astype(nptype) # Convert to HSV and back, as a batch and individually - with self.test_session() as sess: + with self.cached_session() as sess: batch0 = array_ops.placeholder(nptype, shape=shape) with self.test_scope(): batch1 = image_ops.rgb_to_hsv(batch0) @@ -78,7 +78,7 @@ class RGBToHSVTest(xla_test.XLATestCase): data = [0, 5, 13, 54, 135, 
226, 37, 8, 234, 90, 255, 1] for nptype in self.float_types: rgb_np = np.array(data, dtype=nptype).reshape([2, 2, 3]) / 255. - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(nptype) with self.test_scope(): hsv = image_ops.rgb_to_hsv(placeholder) @@ -97,7 +97,7 @@ class RGBToHSVTest(xla_test.XLATestCase): for r, g, b in rgb_flat ]) hsv_np = hsv_np.reshape(4, 4, 4, 3) - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(nptype) with self.test_scope(): hsv_op = image_ops.rgb_to_hsv(placeholder) @@ -108,7 +108,7 @@ class RGBToHSVTest(xla_test.XLATestCase): class AdjustContrastTest(xla_test.XLATestCase): def _testContrast(self, x_np, y_np, contrast_factor): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_np.shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -146,7 +146,7 @@ class AdjustContrastTest(xla_test.XLATestCase): return y_np def _adjustContrastTf(self, x_np, contrast_factor): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(np.float32) with self.test_scope(): y = image_ops.adjust_contrast(x, contrast_factor) @@ -180,7 +180,7 @@ class AdjustHueTest(xla_test.XLATestCase): y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -198,7 +198,7 @@ class AdjustHueTest(xla_test.XLATestCase): y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -216,7 
+216,7 @@ class AdjustHueTest(xla_test.XLATestCase): y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -244,7 +244,7 @@ class AdjustHueTest(xla_test.XLATestCase): return y_v.reshape(x_np.shape) def _adjustHueTf(self, x_np, delta_h): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(dtypes.float32) with self.test_scope(): y = gen_image_ops.adjust_hue(x, delta_h) @@ -324,7 +324,7 @@ class AdjustSaturationTest(xla_test.XLATestCase): y_rgb_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128] y_np = np.array(y_rgb_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) y = self._adjust_saturation(x, saturation_factor) y_tf = y.eval({x: x_np}) @@ -339,7 +339,7 @@ class AdjustSaturationTest(xla_test.XLATestCase): y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) y = self._adjust_saturation(x, saturation_factor) y_tf = y.eval({x: x_np}) @@ -378,7 +378,7 @@ class AdjustSaturationTest(xla_test.XLATestCase): "gb_same", "rgb_same", ] - with self.test_session(): + with self.cached_session(): for x_shape in x_shapes: for test_style in test_styles: x_np = np.random.rand(*x_shape) * 255. 
@@ -410,13 +410,14 @@ class ResizeBilinearTest(xla_test.XLATestCase): image_np, target_shape, expected=None, - large_tolerance=False): + large_tolerance=False, + align_corners=True): if expected is None: self.fail("expected must be specified") - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): image = array_ops.placeholder(image_np.dtype) resized = gen_image_ops.resize_bilinear( - image, target_shape, align_corners=True) + image, target_shape, align_corners=align_corners) out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) if large_tolerance: self.assertAllClose( @@ -433,7 +434,7 @@ class ResizeBilinearTest(xla_test.XLATestCase): self.fail("input_shape must be specified") if expected is None: self.fail("expected must be specified") - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): dtype = dtype or np.float32 grads = array_ops.placeholder(np.float32) resized = gen_image_ops.resize_bilinear_grad( @@ -579,6 +580,27 @@ class ResizeBilinearTest(xla_test.XLATestCase): dtype=np.float32)), large_tolerance=True) + def testNonAlignCorners3x2To6x4(self): + input_data = [[64, 32], [32, 64], [50, 100]] + expected_data = [[64.0, 48.0, 32.0, 32.0], [48.0, 48.0, 48.0, 48.0], + [32.0, 48.0, 64.0, 64.0], [41.0, 61.5, 82.0, 82.0], + [50.0, 75.0, 100.0, 100.0], [50.0, 75.0, 100.0, 100.0]] + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array(input_data, dtype=dtype), [6, 4], + expected=np.array(expected_data, dtype=np.float32), + align_corners=False) + + def testNonAlignCorners6x4To3x2(self): + input_data = [[127, 127, 64, 64], [127, 127, 64, 64], [64, 64, 127, 127], + [64, 64, 127, 127], [50, 50, 100, 100], [50, 50, 100, 100]] + expected_data = [[127, 64], [64, 127], [50, 100]] + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array(input_data, dtype=dtype), [3, 2], + 
expected=np.array(expected_data, dtype=dtype), + align_corners=False) + class NonMaxSuppressionTest(xla_test.XLATestCase): @@ -596,7 +618,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.0, dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, @@ -639,7 +661,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.0, dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, @@ -686,7 +708,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.4, dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py index 45a04f0cf56e88946b946bedacb25ce6da3121b4..58622114e4f552fb71db9b040a39b57d7da0037c 100644 --- a/tensorflow/compiler/tests/listdiff_op_test.py +++ b/tensorflow/compiler/tests/listdiff_op_test.py @@ -33,7 +33,7 @@ class ListDiffTest(xla_test.XLATestCase): def _testListDiff(self, x, y, out, idx): for dtype in [dtypes.int32, dtypes.int64]: for index_dtype in [dtypes.int32, dtypes.int64]: - with self.test_session() 
as sess: + with self.cached_session() as sess: x_tensor = ops.convert_to_tensor(x, dtype=dtype) y_tensor = ops.convert_to_tensor(y, dtype=dtype) with self.test_scope(): diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index 253b45902fba2df64e5234f135b373cd2a0a7e2a..c6ad67993e8bc196a74c9a328df8c9200c92c575 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -58,7 +58,7 @@ class LRNTest(xla_test.XLATestCase): return output def _RunAndVerify(self, dtype): - with self.test_session(): + with self.cached_session(): # random shape shape = np.random.randint(1, 16, size=4) # Make depth at least 2 to make it meaningful @@ -110,7 +110,7 @@ class LRNTest(xla_test.XLATestCase): alpha = 1.0 * np.random.rand() beta = 1.0 * np.random.rand() - with self.test_session(): + with self.cached_session(): in_image = constant_op.constant(in_image_vals, shape=shape) out_image = constant_op.constant(out_image_vals, shape=shape) out_grads = constant_op.constant(out_grads_vals, shape=shape) diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py index 31093c65713df55390c3130b8654fdcb10fbc133..265c0b6d1412de7be3a5bf5e79129cb330ceb162 100644 --- a/tensorflow/compiler/tests/lstm_test.py +++ b/tensorflow/compiler/tests/lstm_test.py @@ -73,7 +73,7 @@ class LSTMTest(test.TestCase): def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar, pad_scalar): - with self.test_session() as sess: + with self.cached_session() as sess: num_inputs = 1 num_nodes = 1 @@ -156,7 +156,7 @@ class LSTMTest(test.TestCase): def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar, pad_scalar): - with self.test_session() as sess: + with self.cached_session() as sess: num_inputs = 1 num_nodes = 1 seq_length = 3 diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 
0d9f99f8a6803ecae5f9233518a1768109161ac0..9222db4b7ebf020c8cee1c0af81e05129fb33c4d 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class MatrixBandPartTest(xla_test.XLATestCase): def _testMatrixBandPart(self, dtype, shape): - with self.test_session(): + with self.cached_session(): batch_shape = shape[:-2] mat = np.ones(shape).astype(dtype) batch_mat = np.tile(mat, batch_shape + [1, 1]) diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 2bb8a97bdaf5836a05501ab9754433e29ae34675..94cd3eeb3179da9b920ea9f03216d602b042a639 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -54,7 +54,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol): clean_a = np.tril(a) if lower else np.triu(a) - with self.test_session() as sess: + with self.cached_session() as sess: placeholder_a = MakePlaceholder(a) placeholder_ca = MakePlaceholder(clean_a) placeholder_b = MakePlaceholder(b) diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index c2592c54cf83d41f0e3bdbc1f4dc9ff276ddb078..f77521a7c49dba39849869ddceb7c0e885147722 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -41,7 +41,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -95,7 +95,7 @@ class 
MomentumOptimizerTest(xla_test.XLATestCase): def testNesterovMomentum(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.1, 0.2], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([0.3, 0.4], dtype=dtype) var0_np = np.array([0.1, 0.2], dtype=dtype) @@ -120,7 +120,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase): def testTensorLearningRateAndMomentum(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index da08225e9fc0d5a8ec21ee9961c4758fa38628b4..a1c07fce732d3b91a7c0550545a03fdab67644d3 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest class NAryOpsTest(xla_test.XLATestCase): def _testNAry(self, op, args, expected, equality_fn=None): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) @@ -126,7 +126,7 @@ class NAryOpsTest(xla_test.XLATestCase): [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32)) def testOneHot(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32)) op = array_ops.one_hot(indices, np.int32(4), @@ -148,7 +148,7 @@ class NAryOpsTest(xla_test.XLATestCase): self.assertAllEqual(output, expected) def testSplitV(self): - with 
self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): output = session.run( array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]], diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index 2f9122645d3c5ccabc8130ac30a3f09cf4bc2de7..f985c5d2d96e06fc0117f3935d61b19c9e8562b1 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -29,14 +29,14 @@ from tensorflow.python.platform import googletest class NullaryOpsTest(xla_test.XLATestCase): def _testNullary(self, op, expected): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): output = op() result = session.run(output) self.assertAllClose(result, expected, rtol=1e-3) def testNoOp(self): - with self.test_session(): + with self.cached_session(): with self.test_scope(): output = control_flow_ops.no_op() # This should not crash. diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py index d68d32057a367776d5b70d5ac21d5618297c605d..7635f89249b7b71e5353e0b7cb1cea5c1f7bca1d 100644 --- a/tensorflow/compiler/tests/oom_test.py +++ b/tensorflow/compiler/tests/oom_test.py @@ -46,7 +46,7 @@ class OutOfMemoryTest(xla_test.XLATestCase): def test_loop(): size = int(2e8) while True: - with self.test_session(): + with self.cached_session(): # Force the compiled code to not be constant by feeding in a # parameter. 
p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1]) diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index a75d99189b5b673261c9e48f1c5998ea0c575594..77bb839409f0c323ff6ed2c8d6bd105d3003b398 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -28,7 +28,7 @@ from tensorflow.python.platform import googletest class PlaceholderTest(xla_test.XLATestCase): def test_placeholder_with_default_default(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(4.0) ph = array_ops.placeholder_with_default(v, shape=[]) out = ph * 2 @@ -36,7 +36,7 @@ class PlaceholderTest(xla_test.XLATestCase): self.assertEqual(8.0, sess.run(out)) def test_placeholder_with_default_fed(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(4.0) ph = array_ops.placeholder_with_default(v, shape=[]) out = ph * 2 diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py index 17f860db61aeda98326a6820771d67ee948b6dda..b6cdd38345b9a9f6b03e8799587e3f6ffe07b407 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -62,7 +62,7 @@ class Pooling3DTest(xla_test.XLATestCase): # numbers from 1. 
x = np.arange(1.0, total_size + 1, dtype=np.float32) x = x.reshape(input_sizes) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): inputs = array_ops.placeholder(dtypes.float32) t = pool_func( inputs, @@ -210,7 +210,7 @@ class Pooling3DTest(xla_test.XLATestCase): strides = [1] + strides + [1] total_size = np.prod(input_sizes) x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: # Use the forward pool function to compute some corresponding outputs # (needed for the CPU device, and we need the shape in both cases). with ops.device("CPU"): diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index 9fc94752ea660f7fb8b2c792180f01485ad04419..d03bd4fdbb7694bc36291faf9b845ec48e26a386 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -89,7 +89,7 @@ class PoolingTest(xla_test.XLATestCase): # numbers from 1. x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32) x = x.reshape(input_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): inputs = array_ops.placeholder(dtypes.float32) t = inputs @@ -324,7 +324,7 @@ class PoolGradTest(xla_test.XLATestCase): # TODO(b/74222344): Fix nan handling for max pool grad. # x[np.random.choice(total_size)] = np.nan x = x.reshape(input_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: # Use the forward pool function to compute some corresponding outputs # (needed for the CPU device, and we need the shape in both cases). 
with ops.device(self.CPU_DEVICE): diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py index 5fa7706d7294f2cffb7d24a56851be02d759335a..86536da7fed0e2309beb32fee9c7c605491592ed 100644 --- a/tensorflow/compiler/tests/powersign_test.py +++ b/tensorflow/compiler/tests/powersign_test.py @@ -64,7 +64,7 @@ class PowerSignTest(xla_test.XLATestCase): base=math.e, beta=0.9): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): # Initialize variables for numpy implementation. m0, m1 = 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py index cde87db63dbfd7c8d823c6fd0e41eee8b23735bb..c41b4171e26af4f7ad0237d7407a5b3691299595 100644 --- a/tensorflow/compiler/tests/proximal_adagrad_test.py +++ b/tensorflow/compiler/tests/proximal_adagrad_test.py @@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_adagrad class ProximalAdagradOptimizerTest(xla_test.XLATestCase): def testResourceProximalAdagradwithoutRegularization(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -60,7 +60,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): self.assertEqual(2, len(opt_vars)) def testProximalAdagradwithoutRegularization2(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -84,7 +84,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval()) def 
testProximalAdagradWithL1(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -108,7 +108,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval()) def testProximalAdagradWithL1_L2(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -151,7 +151,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): return var0.eval(), var1.eval() def testEquivAdagradwithoutRegularization(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val0, val1 = self.applyOptimizer( proximal_adagrad.ProximalAdagradOptimizer( 3.0, @@ -159,7 +159,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): l1_regularization_strength=0.0, l2_regularization_strength=0.0)) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val2, val3 = self.applyOptimizer( adagrad.AdagradOptimizer( 3.0, initial_accumulator_value=0.1)) diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py index 11eb76871133eba8fcd24621afb03e16614fb005..3d808e6b8a71ef9fa60b671d07bfd907e9f58efc 100644 --- a/tensorflow/compiler/tests/proximal_gradient_descent_test.py +++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py @@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_gradient_descent class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): def testResourceProximalGradientDescentwithoutRegularization(self): - with 
self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -53,7 +53,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([-0.09, -0.18]), var1.eval()) def testProximalGradientDescentwithoutRegularization2(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -75,7 +75,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([3.91, 2.82]), var1.eval()) def testProximalGradientDescentWithL1(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -97,7 +97,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([3.67, 2.37]), var1.eval()) def testProximalGradientDescentWithL1_L2(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -137,14 +137,14 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): return var0.eval(), var1.eval() def testEquivGradientDescentwithoutRegularization(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val0, val1 = self.applyOptimizer( proximal_gradient_descent.ProximalGradientDescentOptimizer( 3.0, l1_regularization_strength=0.0, 
l2_regularization_strength=0.0)) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val2, val3 = self.applyOptimizer( gradient_descent.GradientDescentOptimizer(3.0)) diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py index 1b969ee2b3886fca6ec9951d1621ca5af6a673d8..236b1b881dcaffc1a5b0c6395f0605c1d7ef0269 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -71,7 +71,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): x_np = np.random.uniform( low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) - with self.test_session() as sess: + with self.cached_session() as sess: x_tf = array_ops.placeholder(dtype) with self.test_scope(): q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices) @@ -101,8 +101,8 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): @parameterized.parameters(*PARAMS) def testQR(self, rows, cols, dtype): - # TODO(b/111317468): implement full_matrices=False, test other types. - for full_matrices in [True]: + # TODO(b/111317468): Test other types. + for full_matrices in [True, False]: # Only tests the (3, 2) case for small numbers of rows/columns. for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): self._test(dtype, batch_dims + (rows, cols), full_matrices) diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 8c4e16e4e075726d741f6ff8cdfb6b1aad6cd33e..6e183441179ebf2e8c063b333f9328d6fa86cc88 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -39,7 +39,7 @@ class RandomOpsTest(xla_test.XLATestCase): def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. 
- with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = rng(dtype) @@ -79,7 +79,7 @@ class RandomOpsTest(xla_test.XLATestCase): if (self.device in ["XLA_GPU", "XLA_CPU" ]) and (dtype in [dtypes.bfloat16, dtypes.half]): continue - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = random_ops.random_uniform( shape=[1000], dtype=dtype, minval=-2, maxval=33) @@ -99,7 +99,7 @@ class RandomOpsTest(xla_test.XLATestCase): count = 10000000 # TODO(b/34339814): implement inverse erf support for non-F32 types. for dtype in [dtypes.float32]: - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype) y = sess.run(x) @@ -147,7 +147,7 @@ class RandomOpsTest(xla_test.XLATestCase): # TODO(b/26783907): this test requires the CPU backend to implement sort. if self.device in ["XLA_CPU"]: return - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = math_ops.range(1 << 16) shuffle = random_ops.random_shuffle(x) @@ -158,7 +158,7 @@ class RandomOpsTest(xla_test.XLATestCase): self.assertAllEqual(set(result), set(expected)) def testShuffle2d(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = array_ops.diag(math_ops.range(20)) shuffle = random_ops.random_shuffle(x) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index c0ea242044540b1cef44186880ba3cd92b8849d6..0faf0fd8edf355838ccf42f1d6de20ac01faa3db 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -275,13 +275,13 @@ class OpTest : public ::testing::Test { // Select a random element from 'candidates'. 
template - T Choose(gtl::ArraySlice candidates); + T Choose(absl::Span candidates); static constexpr int kDefaultMaxRank = 5; static constexpr int64 kDefaultMaxDimensionSize = 256LL; // Returns true if 'dims' have a size less than tf_xla_max_tensor_size. - bool TensorSizeIsOk(gtl::ArraySlice dims); + bool TensorSizeIsOk(absl::Span dims); // Returns a random dimension size, in the range [min, max). int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize); @@ -307,11 +307,11 @@ class OpTest : public ::testing::Test { // of the type's range. If the shape is omitted, a random shape is used. // TODO(phawkins): generalize this code to a caller-supplied distribution. Tensor RandomTensor(DataType dtype, bool needs_unique_values, - gtl::ArraySlice shape); + absl::Span shape); Tensor RandomTensor(DataType dtype); // Like RandomTensor, but uses values >= 0. - Tensor RandomNonNegativeTensor(DataType dtype, gtl::ArraySlice shape); + Tensor RandomNonNegativeTensor(DataType dtype, absl::Span shape); Tensor RandomNonNegativeTensor(DataType dtype); // Returns a random subset of the integers in the range [0, rank), suitable @@ -415,7 +415,7 @@ void OpTest::Repeatedly(const std::function& fn) { } template -T OpTest::Choose(gtl::ArraySlice candidates) { +T OpTest::Choose(absl::Span candidates) { std::uniform_int_distribution d(0, candidates.size() - 1); return candidates[d(generator())]; } @@ -425,7 +425,7 @@ int64 OpTest::RandomDim(int64 min, int64 max) { return size_distribution(generator()); } -bool OpTest::TensorSizeIsOk(gtl::ArraySlice dims) { +bool OpTest::TensorSizeIsOk(absl::Span dims) { int64 size = 1LL; for (int64 dim : dims) { size *= dim; @@ -451,7 +451,7 @@ std::vector OpTest::RandomDims(int min_rank, int max_rank, } Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, - gtl::ArraySlice shape) { + absl::Span shape) { Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { @@ -548,7 +548,7 @@ Tensor 
OpTest::RandomTensor(DataType dtype) { } Tensor OpTest::RandomNonNegativeTensor(DataType dtype, - gtl::ArraySlice shape) { + absl::Span shape) { Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { @@ -1884,7 +1884,8 @@ TEST_F(OpTest, DynamicStitch) { for (int i = 0; i < n; ++i) { TensorShape shape(index_dims[i]); Tensor t = test::AsTensor( - gtl::ArraySlice(indices, pos, shape.num_elements()), shape); + absl::Span(indices).subspan(pos, shape.num_elements()), + shape); builder.Input(t); pos += t.NumElements(); } diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index cea2ec816f85e88b11e6e80c91c14fca9015f45c..132c59c32c9db0c8759bdbb31f8613c3ef88b485 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import functools import itertools +from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test @@ -30,22 +31,24 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ReduceOpsTest(xla_test.XLATestCase): - +@parameterized.named_parameters(('32_bit_index', dtypes.int32), + ('64_bit_index', dtypes.int64)) +class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs, + index_dtype, rtol=1e-4, atol=1e-4): """Tests that the output of 'tf_reduce_fn' matches numpy's output.""" for test_input in test_inputs: - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): a = array_ops.placeholder(dtype) - index = array_ops.placeholder(dtypes.int32) + index = array_ops.placeholder(index_dtype) out = tf_reduce_fn(a, index) result = sess.run(out, {a: test_input, index: [0]}) self.assertAllClose( @@ -89,22 +92,23 @@ class ReduceOpsTest(xla_test.XLATestCase): np.array([[False, True, False], 
[True, True, False]]), ] - def testReduceSumF32(self): - self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA) + def testReduceSumF32(self, index_dtype): + self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA, + index_dtype) - def testReduceSumC64(self): + def testReduceSumC64(self, index_dtype): self._testReduction(math_ops.reduce_sum, np.sum, np.complex64, - self.COMPLEX_DATA) + self.COMPLEX_DATA, index_dtype) - def testReduceProdF32(self): + def testReduceProdF32(self, index_dtype): self._testReduction(math_ops.reduce_prod, np.prod, np.float32, - self.REAL_DATA) + self.REAL_DATA, index_dtype) - def testReduceProdC64(self): + def testReduceProdC64(self, index_dtype): self._testReduction(math_ops.reduce_prod, np.prod, np.complex64, - self.COMPLEX_DATA) + self.COMPLEX_DATA, index_dtype) - def testReduceMin(self): + def testReduceMin(self, index_dtype): def reference_min(dtype, inp, axis): """Wrapper around np.amin that returns +infinity for an empty input.""" @@ -119,9 +123,9 @@ class ReduceOpsTest(xla_test.XLATestCase): [np.float32, np.int32, np.int64]): self._testReduction(math_ops.reduce_min, functools.partial(reference_min, dtype), dtype, - self.REAL_DATA) + self.REAL_DATA, index_dtype) - def testReduceMax(self): + def testReduceMax(self, index_dtype): def reference_max(dtype, inp, axis): """Wrapper around np.amax that returns -infinity for an empty input.""" @@ -137,23 +141,25 @@ class ReduceOpsTest(xla_test.XLATestCase): [np.float32, np.int32, np.int64]): self._testReduction(math_ops.reduce_max, functools.partial(reference_max, dtype), dtype, - self.REAL_DATA) + self.REAL_DATA, index_dtype) - def testReduceMeanF32(self): + def testReduceMeanF32(self, index_dtype): # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when # reducing across zero inputs. 
self._testReduction(math_ops.reduce_mean, np.mean, np.float32, - self.NONEMPTY_REAL_DATA) + self.NONEMPTY_REAL_DATA, index_dtype) - def testReduceMeanC64(self): + def testReduceMeanC64(self, index_dtype): self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, - self.NONEMPTY_COMPLEX_DATA) + self.NONEMPTY_COMPLEX_DATA, index_dtype) - def testReduceAll(self): - self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA) + def testReduceAll(self, index_dtype): + self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA, + index_dtype) - def testReduceAny(self): - self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) + def testReduceAny(self, index_dtype): + self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA, + index_dtype) class ReduceOpPrecisionTest(xla_test.XLATestCase): @@ -178,7 +184,7 @@ class ReduceOpPrecisionTest(xla_test.XLATestCase): """ for test_input in test_inputs: - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): a = array_ops.placeholder(dtype) index = array_ops.placeholder(dtypes.int32) @@ -213,7 +219,7 @@ class ReduceOpPrecisionTest(xla_test.XLATestCase): bf16_max = np.float32(dtypes.bfloat16.max) f32_max = dtypes.float32.max - value = min(bf16_max, f32_max - bf16_max) + value = min(bf16_max, f32_max - bf16_max) / 2 self._testReduceSum( dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype, itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3)) diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py index c69b6837b0f88ced844faf3713a29a1c14c8790d..ff20ea3f4287b4666684501fa4920435a77b4183 100644 --- a/tensorflow/compiler/tests/reduce_window_test.py +++ b/tensorflow/compiler/tests/reduce_window_test.py @@ -32,7 +32,7 @@ class ReduceWindowTest(xla_test.XLATestCase): """Test cases for xla.reduce_window.""" def _reduce_window(self, operand, init, 
reducer, **kwargs): - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(operand.dtype) with self.test_scope(): output = xla.reduce_window(placeholder, init, reducer, **kwargs) diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..84c67779400f7a800bd88abc32d95058a6c0904d --- /dev/null +++ b/tensorflow/compiler/tests/reshape_op_test.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for slicing.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase): + + @parameterized.named_parameters(('32_bit_index', dtypes.int32), + ('64_bit_index', dtypes.int64)) + def testBasic(self, index_dtype): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[2, 3]) + with self.test_scope(): + shape = constant_op.constant([3, 2], dtype=index_dtype) + o = array_ops.reshape(i, shape) + params = { + i: [[1, 2, 3], [4, 5, 6]], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[1, 2], [3, 4], [5, 6]], result) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py index d01c676e7c2fe705344f26818350c46c30451c67..392290fd92d0c7c928581422433892147374b2dd 100644 --- a/tensorflow/compiler/tests/reverse_ops_test.py +++ b/tensorflow/compiler/tests/reverse_ops_test.py @@ -32,33 +32,40 @@ class ReverseOpsTest(xla_test.XLATestCase): def testReverseOneDim(self): shape = (7, 5, 9, 11) - for revdim in range(len(shape)): + for revdim in range(-len(shape), len(shape)): self._AssertReverseEqual([revdim], shape) def testReverseMoreThanOneDim(self): shape = (7, 5, 9, 11) + # The offset is used to test various (but not all) combinations of negative + # and positive axis indices that are guaranteed to not collide at the same + # index. 
for revdims in itertools.chain.from_iterable( - itertools.combinations(range(len(shape)), k) - for k in range(2, len(shape)+1)): + itertools.combinations(range(-offset, + len(shape) - offset), k) + for k in range(2, + len(shape) + 1) + for offset in range(0, len(shape))): self._AssertReverseEqual(revdims, shape) def _AssertReverseEqual(self, revdims, shape): np.random.seed(120) pval = np.random.randint(0, 100, size=shape).astype(float) - with self.test_session(): + with self.cached_session(): with self.test_scope(): p = array_ops.placeholder(dtypes.int32, shape=shape) axis = constant_op.constant( np.array(revdims, dtype=np.int32), - shape=(len(revdims),), dtype=dtypes.int32) + shape=(len(revdims),), + dtype=dtypes.int32) rval = array_ops.reverse(p, axis).eval({p: pval}) slices = [ - slice(-1, None, -1) if d in revdims else slice(None) - for d in range(len(shape))] - self.assertEqual( - pval[slices].flatten().tolist(), - rval.flatten().tolist()) + slice(-1, None, -1) + if d in revdims or d - len(shape) in revdims else slice(None) + for d in range(len(shape)) + ] + self.assertEqual(pval[slices].flatten().tolist(), rval.flatten().tolist()) if __name__ == '__main__': diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index ccfa63001653537c4d1b7140e3d745c126f9034b..60c2337743b44e9bad61c4d65280eb2b1a1ad9ea 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -35,7 +35,7 @@ class ReverseSequenceTest(xla_test.XLATestCase): seq_lengths, truth, expected_err_re=None): - with self.test_session(): + with self.cached_session(): p = array_ops.placeholder(dtypes.as_dtype(x.dtype)) lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype)) with self.test_scope(): diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index 
ff8bbac911abe73f946464663984ff1626302882..8840a1329a907bddc6ef1cb6dd1c2a6d234def5c 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -55,7 +55,7 @@ class RmspropTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: for centered in [False, True]: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): # Initialize variables for numpy implementation. var0_np = np.array([1.0, 2.0], dtype=dtype) grads0_np = np.array([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 4292352e76ebcef7dbf41df7b857d2604a468117..897db384b7e8067b0460b5f344201f101a4d8479 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -78,7 +78,7 @@ class CumsumTest(xla_test.XLATestCase): def _compare(self, x, axis, exclusive, reverse): np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): p = array_ops.placeholder(x.dtype) tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval( feed_dict={p: x}) @@ -100,7 +100,7 @@ class CumsumTest(xla_test.XLATestCase): for dtype in self.valid_dtypes: x = np.arange(1, 6).reshape([5]).astype(dtype) for axis_dtype in self.axis_dtypes(): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): p = array_ops.placeholder(x.dtype) axis = constant_op.constant(0, axis_dtype) math_ops.cumsum(p, axis).eval(feed_dict={p: x}) @@ -131,7 +131,7 @@ class CumsumTest(xla_test.XLATestCase): def testInvalidAxis(self): x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): input_tensor = ops.convert_to_tensor(x) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, @@ 
-156,7 +156,7 @@ class CumprodTest(xla_test.XLATestCase): def _compare(self, x, axis, exclusive, reverse): np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): p = array_ops.placeholder(x.dtype) prod = math_ops.cumprod(p, axis, exclusive, reverse) tf_out = prod.eval(feed_dict={p: x}) @@ -178,7 +178,7 @@ class CumprodTest(xla_test.XLATestCase): for dtype in self.valid_dtypes: x = np.arange(1, 6).reshape([5]).astype(dtype) for axis_dtype in self.axis_dtypes(): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): p = array_ops.placeholder(x.dtype) axis = constant_op.constant(0, axis_dtype) math_ops.cumprod(x, axis).eval(feed_dict={p: x}) @@ -209,7 +209,7 @@ class CumprodTest(xla_test.XLATestCase): def testInvalidAxis(self): x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): input_tensor = ops.convert_to_tensor(x) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index f606f88545d0b6f0b52cee9b93083a6bd91169bc..693f8513bc54e30060a2e963abd504768535a50a 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -119,7 +119,7 @@ class ScatterNdTest(xla_test.XLATestCase): self._VariableRankTest(np_scatter, tf_scatter, vtype, itype) def _runScatterNd(self, indices, updates, shape): - with self.test_session(): + with self.cached_session(): updates_placeholder = array_ops.placeholder(updates.dtype) indices_placeholder = array_ops.placeholder(indices.dtype) with self.test_scope(): diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 
772c20fd424577c3e06eeae409f424b77b52aa8a..287bb0d84e24de3bdcde3aa4c61acee00626e88f 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -32,7 +32,7 @@ class SegmentReductionOpsTest(xla_test.XLATestCase): """Test cases for segment reduction ops.""" def _segmentReduction(self, op, data, indices, num_segments): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): d = array_ops.placeholder(data.dtype, shape=data.shape) if isinstance(indices, int): i = array_ops.placeholder(np.int32, shape=[]) diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index 6c4890565d2083a9493abc59bd563c4dd9fdb186..2c611a959e1d71c53e44bc92c31258153d01507d 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -29,7 +29,7 @@ class SliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[10]) with self.test_scope(): o = array_ops.slice(i, [2], [4]) @@ -40,9 +40,22 @@ class SliceTest(xla_test.XLATestCase): self.assertAllEqual([2, 3, 4, 5], result) + def testZeroSlice(self): + for dtype in self.numeric_types: + with self.cached_session(): + i = array_ops.placeholder(dtype, shape=[2]) + with self.test_scope(): + o = array_ops.slice(i, [0], [0]) + params = { + i: [0, 1], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([], result) + def test3D(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) with self.test_scope(): o = array_ops.slice(i, [1, 2, 2], [1, 1, 4]) @@ -64,7 +77,7 @@ class SliceTest(xla_test.XLATestCase): def test3DWithDynamicBegin(self): """Tests a slice where the start offset is not known at compile time.""" 
for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) begin = array_ops.placeholder(dtypes.int32, shape=[3]) with self.test_scope(): @@ -88,7 +101,7 @@ class SliceTest(xla_test.XLATestCase): def test3DWithDynamicBeginAndNegativeSize(self): """Tests a slice where `begin` is fed dynamically and `size` contains -1.""" for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) begin = array_ops.placeholder(dtypes.int32, shape=[3]) with self.test_scope(): @@ -114,7 +127,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[10]) with self.test_scope(): o = array_ops.strided_slice(i, [2], [6], [2]) @@ -127,7 +140,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test1DNegativeStride(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[10]) with self.test_scope(): o = array_ops.strided_slice(i, [6], [2], [-2]) @@ -140,7 +153,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test2DDegenerate(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): o = array_ops.strided_slice(i, [-1, 0], [0, 3]) @@ -154,7 +167,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test2DDegenerateNegativeStride(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): o = array_ops.strided_slice(i, [0, 0], [-1, 3], [-1, 1]) @@ -168,7 +181,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test3D(self): for dtype in self.numeric_types: - with self.test_session(): + 
with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) with self.test_scope(): o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2]) @@ -189,7 +202,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test3DNegativeStride(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 4, 10]) with self.test_scope(): o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2]) diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 7ff01be3cb4848d6bb85b8ab96b3ee1db6889791..51c04b5c4796474700a92a8b23a1cbdf533fcbb4 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -32,7 +32,7 @@ from tensorflow.python.platform import test class XlaSortOpTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) @@ -131,7 +131,7 @@ class XlaSortOpTest(xla_test.XLATestCase): if bfloat16 not in self.numeric_types: return - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.bfloat16) with self.test_scope(): topk = nn_ops.top_k(p, k=4) @@ -153,7 +153,7 @@ class XlaSortOpTest(xla_test.XLATestCase): if bfloat16 not in self.numeric_types: return - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.bfloat16) with self.test_scope(): topk = nn_ops.top_k(p, k=6) diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index c685bc548f9f6f8f7723c6f94dfd45f5420b4a67..33b84cec7188c85a3bacb20a6df29c73adbd107c 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ 
b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -72,7 +72,7 @@ class SpaceToBatchTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" def _testPad(self, inputs, paddings, block_size, outputs): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self.float_types: # outputs = space_to_batch(inputs) placeholder = array_ops.placeholder(dtype) @@ -155,7 +155,7 @@ class SpaceToBatchNDTest(xla_test.XLATestCase): def _testPad(self, inputs, block_shape, paddings, outputs): block_shape = np.array(block_shape) paddings = np.array(paddings).reshape((len(block_shape), 2)) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self.float_types: # TODO(b/68813416): Skip bfloat16's as the input type for direct is # float32 and results in a mismatch, while making testDirect provide the diff --git a/tensorflow/compiler/tests/sparse_to_dense_op_test.py b/tensorflow/compiler/tests/sparse_to_dense_op_test.py index 3db8101c4bfbb1b53c7318a36519612984d6f179..07afd1ab3fb78d5accc52ee2382af0b9fb8079d3 100644 --- a/tensorflow/compiler/tests/sparse_to_dense_op_test.py +++ b/tensorflow/compiler/tests/sparse_to_dense_op_test.py @@ -45,32 +45,32 @@ def _SparseToDense(sparse_indices, class SparseToDenseTest(xla_test.XLATestCase): def testInt(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], 1, 0) np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32) self.assertAllClose(np_ans, tf_ans) def testFloat(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0) np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32) self.assertAllClose(np_ans, tf_ans) def testSetValue(self): - with self.test_session(), self.test_scope(): + with 
self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1) np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32) self.assertAllClose(np_ans, tf_ans) def testSetSingleValue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], 1, -1) np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32) self.assertAllClose(np_ans, tf_ans) def test2d(self): # pylint: disable=bad-whitespace - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1) np_ans = np.array([[-1, -1, -1, -1], [-1, -1, -1, 1], @@ -78,12 +78,12 @@ class SparseToDenseTest(xla_test.XLATestCase): self.assertAllClose(np_ans, tf_ans) def testZeroDefault(self): - with self.test_session(): + with self.cached_session(): x = sparse_ops.sparse_to_dense(2, [4], 7).eval() self.assertAllEqual(x, [0, 0, 7, 0]) def test3d(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1) np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1 np_ans[1, 3, 0] = 1 @@ -91,25 +91,25 @@ class SparseToDenseTest(xla_test.XLATestCase): self.assertAllClose(np_ans, tf_ans) def testBadShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"): _SparseToDense([1, 3], [[5], [3]], 1, -1) def testBadValue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): with self.assertRaisesOpError( r"sparse_values has incorrect shape \[2,1\], " r"should be \[\] or \[2\]"): _SparseToDense([1, 3], [5], [[5], [3]], -1) def testBadNumValues(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): with self.assertRaisesOpError( 
r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): _SparseToDense([1, 3], [5], [1, 2, 3], -1) def testBadDefault(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): with self.assertRaisesOpError("default_value should be a scalar"): _SparseToDense([1, 3], [5], [1, 2], [0]) diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py index b7dd787feff2b22a9cfb5d43a4ba6ceb6eb0b301..720595a159eea997be2246c4c7dad49612b257eb 100644 --- a/tensorflow/compiler/tests/stack_ops_test.py +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -31,7 +31,7 @@ from tensorflow.python.platform import test class StackOpTest(xla_test.XLATestCase): def testStackPushPop(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): size = array_ops.placeholder(dtypes.int32) v = array_ops.placeholder(dtypes.float32) h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") @@ -41,7 +41,7 @@ class StackOpTest(xla_test.XLATestCase): self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]})) def testStackPushPopSwap(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): a = np.arange(2000) x = array_ops.placeholder(dtypes.float32) h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") @@ -51,7 +51,7 @@ class StackOpTest(xla_test.XLATestCase): self.assertAllClose(a, c1.eval({x: a})) def testMultiStack(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): v = array_ops.placeholder(dtypes.float32) h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") c1 = gen_data_flow_ops.stack_push_v2(h1, v) @@ -66,7 +66,7 @@ class StackOpTest(xla_test.XLATestCase): def testSameNameStacks(self): """Different stacks with the same name do not interfere.""" - with self.test_session() as sess, self.test_scope(): + with 
self.cached_session() as sess, self.test_scope(): v1 = array_ops.placeholder(dtypes.float32) v2 = array_ops.placeholder(dtypes.float32) h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") @@ -84,14 +84,14 @@ class StackOpTest(xla_test.XLATestCase): self.assertAllClose(out2, 5.0) def testCloseStack(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): size = array_ops.placeholder(dtypes.int32) h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") c1 = gen_data_flow_ops.stack_close_v2(h) sess.run(c1, {size: 5}) def testPushCloseStack(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): v = array_ops.placeholder(dtypes.float32) h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") c = gen_data_flow_ops.stack_push_v2(h, v) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index d162675ef840131485128414b4a29e3cd89c8761..1bea7d9355e40c5a71f848dabc0fa7fa760429d2 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -38,7 +38,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def testDeterminism(self): # Stateless values should be equal iff the seeds are equal (roughly) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) seeds = [(x, y) for x in range(5) for y in range(5)] * 3 for stateless_op in [ @@ -55,7 +55,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertEqual(s0 == s1, np.all(v0 == v1)) def testRandomUniformIsInRange(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, 
shape=[2]) x = stateless.stateless_random_uniform( @@ -74,7 +74,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def testDistributionOfStatelessRandomUniform(self): """Use Pearson's Chi-squared test to test for uniformity.""" - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 1000 @@ -88,7 +88,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertTrue(self._chi_squared(y, 10) < 16.92) def testRandomNormalIsFinite(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless.stateless_random_uniform( @@ -111,7 +111,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def testDistributionOfStatelessRandomNormal(self): """Use Anderson-Darling test to test distribution appears normal.""" - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 1000 @@ -126,7 +126,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def testTruncatedNormalIsInRange(self): # TODO(b/34339814): implement inverse erf support for non-F32 types. 
for dtype in [dtypes.float32]: - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 10000000 x = stateless.stateless_truncated_normal( diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index f332aa2e9b97e13654cf9b10588c18fed32f7ad4..78244d0b366d9128a4c59f786e4c5ac12e743b75 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -44,7 +44,7 @@ def _make_converter(dtype): class TensorArrayTest(xla_test.XLATestCase): def testTensorArrayWriteRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -66,7 +66,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([], flow_val.shape) def _testTensorArrayWritePack(self, tf_dtype): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -86,7 +86,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayWritePack(dtype) def testEmptyTensorArrayPack(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) @@ -100,7 +100,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([3, 0, 1], c0.eval().shape) def _testTensorArrayWriteConcat(self, tf_dtype): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -121,7 +121,7 @@ class TensorArrayTest(xla_test.XLATestCase): 
self._testTensorArrayWriteConcat(dtype) def _testTensorArrayUnpackRead(self, tf_dtype): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -176,7 +176,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayUnpackReadMaybeLegacy() def _testTensorArraySplitRead(self, tf_dtype): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -228,7 +228,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArraySplitRead(dtype) def testTensorGradArrayWriteRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -261,7 +261,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([[-2.0]], g_d2) def testTensorGradArrayDynamicWriteRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -300,7 +300,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(3, g_vs) def testTensorGradAccessTwiceReceiveSameObject(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3, element_shape=[1, 2]) @@ -317,7 +317,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([[4.0, 5.0]], d_r1_0) def testTensorArrayWriteWrongIndexOrDataTypeFails(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = 
tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) @@ -331,7 +331,7 @@ class TensorArrayTest(xla_test.XLATestCase): # the first type, but try to read the other type. if len(self.float_types) > 1: dtype1, dtype2 = list(self.float_types)[:2] - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtype1, tensor_array_name="foo", size=3) @@ -347,7 +347,7 @@ class TensorArrayTest(xla_test.XLATestCase): w0.read(1) def testTensorArraySplitIncompatibleShapesFails(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -379,7 +379,7 @@ class TensorArrayTest(xla_test.XLATestCase): ta.split([1.0], [1]).flow.eval() def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False) @@ -410,7 +410,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayWriteGradientAddMultipleAdds(dtype) def testMultiTensorArray(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): h1 = tensor_array_ops.TensorArray( size=1, dtype=dtypes.float32, tensor_array_name="foo") w1 = h1.write(0, 4.0) @@ -425,7 +425,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllClose(9.0, r.eval()) def _testTensorArrayGradientWriteReadType(self, dtype): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.as_dtype(dtype), tensor_array_name="foo", @@ -478,7 +478,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayGradientWriteReadType(dtype) def 
_testTensorArrayGradientWritePackConcatAndRead(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -513,7 +513,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayGradientWritePackConcatAndRead() def testTensorArrayReadTwice(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) ta_readtwice = tensor_array_ops.TensorArray( @@ -529,7 +529,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([1.0, -1.0], r1_readtwice.eval()) def _testTensorArrayGradientUnpackRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -557,7 +557,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayGradientUnpackRead() def testTensorArrayGradientSplitConcat(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=2) @@ -581,21 +581,21 @@ class TensorArrayTest(xla_test.XLATestCase): grad_vals[0]) def testCloseTensorArray(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) c1 = ta.close() session.run(c1) def testSizeTensorArray(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) s = ta.size() self.assertAllEqual(3, s.eval()) def testWriteCloseTensorArray(self): - with 
self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -608,7 +608,7 @@ class TensorArrayTest(xla_test.XLATestCase): # TODO(phawkins): implement while loops. # def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): # np_dtype = dtype.as_numpy_dtype - # with self.test_session() as session, self.test_scope(): + # with self.cached_session() as session, self.test_scope(): # v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)) # var = variables.Variable(np.arange(100, 105, dtype=np_dtype)) # state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype)) @@ -692,7 +692,7 @@ class TensorArrayTest(xla_test.XLATestCase): # dynamic_size=True, dtype=dtypes.float32) # def testGradSerialTwoLoops(self): - # with self.test_session(), self.test_scope(): + # with self.cached_session(), self.test_scope(): # num_steps = 100 # acc = tensor_array_ops.TensorArray( # dtype=dtypes.float32, @@ -725,7 +725,7 @@ class TensorArrayTest(xla_test.XLATestCase): # self.assertAllClose(31.0, grad.eval()) def testSumOfTwoReadVariablesWithoutRepeatGrad(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): a = array_ops.identity( np.arange( 3 * 5, dtype=np.float32).reshape(3, 5) + 1) @@ -757,7 +757,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(joint_grad_b_t, g0) def testWriteShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) c0 = constant_op.constant([4.0, 5.0]) @@ -781,7 +781,7 @@ class TensorArrayTest(xla_test.XLATestCase): w0.write(0, c2) def testPartlyUnknownShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( 
dtype=dtypes.float32, tensor_array_name="foo", size=6) @@ -821,7 +821,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list()) def _testUnpackShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -846,7 +846,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testUnpackShape() def testSplitShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -867,7 +867,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) def testWriteUnknownShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -879,7 +879,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) def _testGradientWhenNotAllComponentsRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) x = constant_op.constant([2.0, 3.0]) w = ta.unstack(x) @@ -893,7 +893,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testGradientWhenNotAllComponentsRead() def _testTensorArrayEvalEmpty(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=0, infer_shape=False) with self.assertRaisesOpError( @@ -906,7 +906,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayEvalEmpty() def _testTensorArrayEvalEmptyWithDefault(self): - with self.test_session(), self.test_scope(): + with 
self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=0, infer_shape=True) self.assertEqual(0, ta.size().eval()) @@ -921,7 +921,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayEvalEmptyWithDefault() def testTensorArrayScatterReadAndGradients(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -946,7 +946,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0]) def testTensorArrayWriteGatherAndGradients(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -974,7 +974,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(expected_grad, grad_vals[0]) def testTensorArrayIdentity(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2, infer_shape=False) ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4, diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index effa5a59fee7dda543b2c409dfaa27a972a55808..55a992195f2df72677b77757ae86171fa662439f 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -31,7 +31,7 @@ from tensorflow.python.platform import googletest class TernaryOpsTest(xla_test.XLATestCase): def _testTernary(self, op, a, b, c, expected): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") pb = 
array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 73adb0d243b3b27e6c6ba669b2fd134a5976a2ec..5b0e57f83ff4b5a8d1891bef0675074bd67addce 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -65,7 +65,7 @@ class UnaryOpsTest(xla_test.XLATestCase): rtol: relative tolerance for equality test. atol: absolute tolerance for equality test. """ - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): pinp = array_ops.placeholder( dtypes.as_dtype(inp.dtype), inp.shape, name="a") @@ -202,7 +202,7 @@ class UnaryOpsTest(xla_test.XLATestCase): # Disable float16 testing for now if dtype != np.float16: x = np.arange(-10, 10, 1).astype(dtype) - with self.test_session() as session: + with self.cached_session() as session: erf_x = session.run(math_ops.erf(x)) erfc_x = session.run(math_ops.erfc(x)) @@ -396,6 +396,11 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array( [[True, False, True], [False, True, True]], dtype=np.bool)) + self._assertOpOutputMatchesExpected( + math_ops.lgamma, + np.array(0.5, dtype=dtype), + expected=np.array(np.log(np.pi) / 2, dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.lgamma, np.array( @@ -420,6 +425,19 @@ class UnaryOpsTest(xla_test.XLATestCase): ], dtype=dtype)) + # The actual result is complex. Take the real part. 
+ self._assertOpOutputMatchesExpected( + math_ops.lgamma, + np.array([-1 / 2, -5 / 2, -9 / 2], dtype=dtype), + expected=np.array( + [ + np.log(np.pi) / 2 + np.log(2), + np.log(np.pi) / 2 - np.log(15) + np.log(8), + np.log(np.pi) / 2 - np.log(945) + np.log(32), + ], + dtype=dtype), + atol=1e-4) + self._assertOpOutputMatchesExpected( math_ops.digamma, np.array( diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index b637cf31cfc303ebe84ce8307ef4ad8b0b5cd720..4ee144beb7f3243be069d59ee4a613484fe183b3 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -43,7 +43,7 @@ class WhileTest(xla_test.XLATestCase): def loop_cond(step): return step < 10 - with self.test_session() as sess: + with self.cached_session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) with self.test_scope(): loop_outputs = xla.while_loop([init_index], loop_cond, loop_body) @@ -65,7 +65,7 @@ class WhileTest(xla_test.XLATestCase): del rsum return step < 10 - with self.test_session() as sess: + with self.cached_session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) init_sum = array_ops.placeholder(dtypes.float32, []) with self.test_scope(): @@ -91,7 +91,7 @@ class WhileTest(xla_test.XLATestCase): del rsum return step < 10 - with self.test_session() as sess: + with self.cached_session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) init_sum = array_ops.placeholder(dtypes.complex64, []) with self.test_scope(): @@ -117,7 +117,7 @@ class WhileTest(xla_test.XLATestCase): del x return step < 10 - with self.test_session() as sess: + with self.cached_session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) with self.test_scope(): loop_outputs = xla.while_loop([init_index, 42], loop_cond, loop_body) diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index 
85084bb1240cf05f6eabfbea772df113cabe613c..28d61fb07dcb665fa0dbe3f3e566e291e24fa662 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -37,7 +37,7 @@ class XlaDeviceTest(xla_test.XLATestCase): [16384, 1], [1, 16384], [1, 20000, 1, 1]] for dtype in self.numeric_types: for shape in shapes: - with self.test_session() as sess: + with self.cached_session() as sess: with ops.device("CPU"): x = array_ops.placeholder(dtype, shape) with self.test_scope(): @@ -58,7 +58,7 @@ class XlaDeviceTest(xla_test.XLATestCase): ]) shape = (10, 10) for unsupported_dtype in test_types - self.all_types: - with self.test_session() as sess: + with self.cached_session() as sess: with ops.device("CPU"): x = array_ops.placeholder(unsupported_dtype, shape) with self.test_scope(): @@ -78,7 +78,7 @@ class XlaDeviceTest(xla_test.XLATestCase): pass def testControlTrigger(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = gen_control_flow_ops.control_trigger() sess.run(x) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f026df6c0c28fcbceaa0493871bc12c2d23b1f --- /dev/null +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -0,0 +1,301 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for XLA op wrappers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected, + equality_fn=None): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + result = session.run(output, feeds) + if not equality_fn: + equality_fn = self.assertAllClose + equality_fn(result, expected, rtol=1e-3) + + def testAdd(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.add, + args=(np.array([1, 2, 3], dtype=dtype), + np.array([4, 5, 6], dtype=dtype)), + expected=np.array([5, 7, 9], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x, y: xla.add(x, y, broadcast_dims=(0,)), + args=(np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([7, 11], dtype=dtype)), + expected=np.array([[8, 9], [14, 15]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x, y: xla.add(x, y, broadcast_dims=(1,)), + args=(np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([7, 11], dtype=dtype)), + expected=np.array([[8, 13], [10, 15]], dtype=dtype)) + + def testBroadcast(self): + for dtype in self.numeric_types: + v = np.arange(4, 
dtype=np.int32).astype(dtype).reshape([2, 2]) + self._assertOpOutputMatchesExpected( + lambda x: xla.broadcast(x, (7, 42)), + args=(v,), + expected=np.tile(v, (7, 42, 1, 1))) + + def testShiftRightLogical(self): + self._assertOpOutputMatchesExpected( + xla.shift_right_logical, + args=(np.array([-1, 16], dtype=np.int32), np.int32(4)), + expected=np.array([0x0FFFFFFF, 1], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + xla.shift_right_logical, + args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), + expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32)) + + def testShiftRightArithmetic(self): + self._assertOpOutputMatchesExpected( + xla.shift_right_arithmetic, + args=(np.array([-1, 16], dtype=np.int32), np.int32(4)), + expected=np.array([-1, 1], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + xla.shift_right_arithmetic, + args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), + expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32)) + + PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT, + xla_data_pb2.PrecisionConfigProto.HIGH, + xla_data_pb2.PrecisionConfigProto.HIGHEST) + + @parameterized.parameters(*PRECISION_VALUES) + def testConv(self, precision): + for dtype in set(self.float_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + def conv_1d_fn(lhs, rhs): + dnums = xla_data_pb2.ConvolutionDimensionNumbers() + num_spatial_dims = 1 + dnums.input_batch_dimension = 0 + dnums.input_feature_dimension = 1 + dnums.output_batch_dimension = 0 + dnums.output_feature_dimension = 1 + dnums.kernel_output_feature_dimension = 0 + dnums.kernel_input_feature_dimension = 1 + dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + precision_config = None + if precision: + precision_config = xla_data_pb2.PrecisionConfigProto() + 
precision_config.operand_precision.extend([precision, precision]) + return xla.conv( + lhs, + rhs, + window_strides=(1,), + padding=((2, 1),), + lhs_dilation=(1,), + rhs_dilation=(2,), + dimension_numbers=dnums) + + self._assertOpOutputMatchesExpected( + conv_1d_fn, + args=( + np.array([[[3, 4, 5, 6]]], dtype=dtype), + np.array([[[-2, -3]]], dtype=dtype), + ), + expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype)) + + @parameterized.parameters(*PRECISION_VALUES) + def testDotGeneral(self, precision): + for dtype in self.float_types: + + def dot_fn(lhs, rhs): + dnums = xla_data_pb2.DotDimensionNumbers() + dnums.lhs_contracting_dimensions.append(2) + dnums.rhs_contracting_dimensions.append(1) + dnums.lhs_batch_dimensions.append(0) + dnums.rhs_batch_dimensions.append(0) + precision_config = None + if precision: + precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config.operand_precision.extend([precision, precision]) + return xla.dot_general( + lhs, + rhs, + dimension_numbers=dnums, + precision_config=precision_config) + + lhs = np.array( + [ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + ], dtype=dtype) + rhs = np.array( + [ + [[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]], + ], dtype=dtype) + self._assertOpOutputMatchesExpected( + dot_fn, + args=(lhs, rhs), + expected=np.array( + [ + [[9, 12, 15], [19, 26, 33]], + [[95, 106, 117], [129, 144, 159]], + ], + dtype=dtype)) + + def testNeg(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.neg, + args=(np.array([1, 2, 3], dtype=dtype),), + expected=np.array([-1, -2, -3], dtype=dtype)) + + def testPad(self): + for dtype in self.numeric_types: + + def pad_fn(x): + return xla.pad( + x, + padding_value=7, + padding_low=[2, 1], + padding_high=[1, 2], + padding_interior=[1, 0]) + + self._assertOpOutputMatchesExpected( + pad_fn, + args=(np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]),), + expected=np.array( + [[7, 7, 7, 7, 7], [7, 7, 7, 7, 7], [7, 0, 1, 7, 
7], + [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]], + dtype=dtype)) + + def testReduce(self): + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def sum_reducer(x, y): + return x + y + + def sum_reduction(dims): + + def fn(x): + return xla.reduce( + x, init_value=0, dimensions_to_reduce=dims, reducer=sum_reducer) + + return fn + + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4])) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[0]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([12, 15, 18, 21], dtype=dtype)) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[1]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([6, 22, 38], dtype=dtype)) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[0, 1]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=dtype(66)) + + @function.Defun(dtype, dtype) + def mul_reducer(x, y): + return x * y + + def mul_reduction(dims): + + def fn(x): + return xla.reduce( + x, init_value=1, dimensions_to_reduce=dims, reducer=mul_reducer) + + return fn + + self._assertOpOutputMatchesExpected( + mul_reduction(dims=[0]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([0, 45, 120, 231], dtype=dtype)) + + def testSelectAndScatter(self): + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def add_scatter(x, y): + return x + y + + @function.Defun(dtype, dtype) + def ge_select(x, y): + return x >= y + + def test_fn(operand, source): + return xla.select_and_scatter( + operand, + window_dimensions=[2, 3, 1, 1], + 
window_strides=[2, 2, 1, 1], + padding=[[0, 0]] * 4, + source=source, + init_value=0, + select=ge_select, + scatter=add_scatter) + + self._assertOpOutputMatchesExpected( + test_fn, + args=(np.array( + [[7, 2, 5, 3, 8], [3, 8, 9, 3, 4], [1, 5, 7, 5, 6], + [0, 6, 2, 10, 2]], + dtype=dtype).reshape((4, 5, 1, 1)), + np.array([[2, 6], [3, 1]], dtype=dtype).reshape((2, 2, 1, 1))), + expected=np.array( + [[0, 0, 0, 0, 0], [0, 0, 8, 0, 0], [0, 0, 3, 0, 0], + [0, 0, 0, 1, 0]], + dtype=dtype).reshape((4, 5, 1, 1))) + + def testTranspose(self): + for dtype in self.numeric_types: + v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]) + self._assertOpOutputMatchesExpected( + lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index fda32c8a1c9491e0dadceec0d7265e1002d41528..0797b2cb17f5aae4080f339a201b44d69bbb2187 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -39,6 +39,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -88,6 +89,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -211,6 +213,8 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], alwayslink = 1, ) @@ -220,13 +224,11 @@ cc_library( srcs = [ "literal_util.cc", "shape_util.cc", - "str_util.cc", "type_util.cc", ], hdrs = [ "literal_util.h", "shape_util.h", - "str_util.h", "type_util.h", ], visibility = [":friends"], @@ -238,6 +240,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:span", 
], ) @@ -255,6 +258,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -287,6 +291,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:optional", ], ) @@ -305,6 +310,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -372,19 +378,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - ], -) - -tf_cc_test( - name = "str_util_test", - srcs = [ - "str_util_test.cc", - ], - deps = [ - ":common", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -442,22 +436,97 @@ cc_library( ], ) +cc_library( + name = "functionalize_control_flow_util", + srcs = [ + "functionalize_control_flow_util.cc", + ], + hdrs = [ + "functionalize_control_flow_util.h", + ], + deps = [ + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "functionalize_cond", + srcs = [ + "functionalize_cond.cc", + ], + hdrs = [ + "functionalize_cond.h", + ], + deps = [ + ":functionalize_control_flow_util", + ":tf2xla_util", + "//tensorflow/compiler/jit:union_find", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + 
cc_library( name = "functionalize_control_flow", - srcs = ["functionalize_control_flow.cc"], - hdrs = ["functionalize_control_flow.h"], + srcs = [ + "functionalize_control_flow.cc", + ], + hdrs = [ + "functionalize_control_flow.h", + ], deps = [ + ":functionalize_cond", + ":functionalize_control_flow_util", + ":functionalize_while", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", - "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "functionalize_while", + srcs = [ + "functionalize_while.cc", + ], + hdrs = [ + "functionalize_while.h", + ], + deps = [ + ":functionalize_control_flow_util", + ":tf2xla_util", + "//tensorflow/compiler/jit:union_find", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla/ops:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", ], ) @@ -485,6 +554,32 @@ tf_cc_test( ], ) +tf_cc_test( + name = "functionalize_cond_test", + srcs = ["functionalize_cond_test.cc"], + deps = [ + ":functionalize_cond", + ":functionalize_control_flow", + ":test_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/compiler/tf2xla/cc:xla_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + 
"//tensorflow/core:framework_internal", + "//tensorflow/core:ops", + "//tensorflow/core:resource_variable_ops_op_lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "test_util", testonly = 1, @@ -508,3 +603,30 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "resource_operation_table", + srcs = ["resource_operation_table.cc"], + hdrs = ["resource_operation_table.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/algorithm:container", + ], +) + +tf_cc_test( + name = "resource_operation_table_test", + srcs = ["resource_operation_table_test.cc"], + deps = [ + ":resource_operation_table", + ":xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index de1008803d69fefa415c7bdbe6c27a62e625b417..e8673d77903bd5a1a85412e9dfa86437f73d56bc 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -23,11 +23,11 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" namespace tensorflow { - // Backwards dataflow analysis that finds arguments to a graph that must be // compile-time constants. Status BackwardsConstAnalysis(const Graph& g, - std::vector* compile_time_const_args) { + std::vector* compile_time_const_args, + std::vector* compile_time_const_nodes) { // Operators that don't look at the data of their inputs, just the shapes. 
const std::unordered_set metadata_ops = { "Rank", @@ -36,9 +36,16 @@ Status BackwardsConstAnalysis(const Graph& g, "Size", }; + std::vector compile_time_const_nodes_impl; + if (compile_time_const_nodes) { + CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids()); + } else { + compile_time_const_nodes_impl.resize(g.num_node_ids()); + compile_time_const_nodes = &compile_time_const_nodes_impl; + } + Status status; - std::unordered_set must_be_const; - auto visit = [&status, &metadata_ops, &must_be_const, + auto visit = [&status, &metadata_ops, compile_time_const_nodes, compile_time_const_args](Node* node) { if (!status.ok()) return; @@ -47,17 +54,19 @@ Status BackwardsConstAnalysis(const Graph& g, // If this node must be const, and it isn't a metadata op, then all of its // parents must be const. - if (must_be_const.find(node) != must_be_const.end()) { + if ((*compile_time_const_nodes)[node->id()]) { if (node->type_string() == "_Arg") { int index; status = GetNodeAttr(node->attrs(), "index", &index); if (!status.ok()) return; - compile_time_const_args->at(index) = true; + if (compile_time_const_args) { + (*compile_time_const_args)[index] = true; + } return; } for (const Edge* pred : node->in_edges()) { if (!pred->IsControlEdge()) { - must_be_const.insert(pred->src()); + (*compile_time_const_nodes)[pred->src()->id()] = true; } } return; @@ -80,7 +89,7 @@ Status BackwardsConstAnalysis(const Graph& g, for (Edge const* edge : node->in_edges()) { if (edge->dst_input() >= name_range->second.first && edge->dst_input() < name_range->second.second) { - must_be_const.insert(edge->src()); + (*compile_time_const_nodes)[edge->src()->id()] = true; } } } diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h index 634b97d7e3760c0344c948a56353ade243284aa6..af57e5a4033248e3fd32dabeda252c4ca0a44050 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.h +++ b/tensorflow/compiler/tf2xla/const_analysis.h @@ -23,10 +23,18 @@ 
limitations under the License. namespace tensorflow { -// Backwards dataflow analysis that finds arguments (_Arg nodes) to a graph that -// must be compile-time constants. +// Backwards dataflow analysis that finds nodes in a graph that must be +// compile-time constants for us to be able to lower the graph to XLA. +// +// The indices of the arguments to `graph` that must be constant are returned in +// `compile_time_const_arg_indices`, if `compile_time_const_arg_indices` is not +// null. +// +// The ids of the nodes in `graph` that must be constant are returned in +// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null. Status BackwardsConstAnalysis(const Graph& graph, - std::vector* compile_time_const_args); + std::vector* compile_time_const_arg_indices, + std::vector* compile_time_const_nodes); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index 992b12c06db5efc0ae54284d0ea77017c1c79aca..56065be894697bc72ecc0089c665c19aafee7bf8 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -38,17 +39,23 @@ TEST(ConstAnalysisTest, Basics) { auto c = ops::Reshape(root, arg2, b); auto d = ops::Mul(root, c, ops::Sum(root, arg3, arg3)); - Graph graph(OpRegistry::Global()); - TF_ASSERT_OK(root.ToGraph(&graph)); + FixupSourceAndSinkEdges(root.graph()); std::vector const_args(4, false); - TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + std::vector const_nodes(root.graph()->num_node_ids(), false); + TF_ASSERT_OK( + BackwardsConstAnalysis(*root.graph(), &const_args, &const_nodes)); // Arg 0 doesn't need to be constant since the graph only uses its shape. // Arg 1 must be constant because it flows to the shape argument of a Reshape. // Arg 2 is used only as the value input to a Reshape and need not be const. // Arg 3 is used as the reduction-indices argument to Sum and must be const. 
EXPECT_EQ(const_args, std::vector({false, true, false, true})); + + EXPECT_FALSE(const_nodes[arg0.node()->id()]); + EXPECT_TRUE(const_nodes[arg1.node()->id()]); + EXPECT_FALSE(const_nodes[arg2.node()->id()]); + EXPECT_TRUE(const_nodes[arg3.node()->id()]); } // Regression test for a case where the backward const analysis did @@ -73,7 +80,8 @@ TEST(ConstAnalysisTest, TopologicalOrder) { TF_ASSERT_OK(root.ToGraph(&graph)); std::vector const_args(3, false); - TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, + /*compile_time_const_nodes=*/nullptr)); EXPECT_EQ(const_args, std::vector({true, true, false})); } @@ -93,7 +101,8 @@ TEST(ConstAnalysisTest, DontFollowControlDependencies) { TF_ASSERT_OK(root.ToGraph(&graph)); std::vector const_args(2, false); - TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, + /*compile_time_const_nodes=*/nullptr)); EXPECT_EQ(const_args, std::vector({false, true})); } diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc new file mode 100644 index 0000000000000000000000000000000000000000..b5667ca0d3ba35bea9da2d702b5b49fb38fe6f02 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -0,0 +1,1385 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/node_builder.h" + +using xla::StatusOr; + +namespace tensorflow { +namespace functionalize_cond { + +string DebugString(const CondStateMap::CondNode& node) { + return node.ToString(); +} + +// TODO(jpienaar): Move to OutputTensor. +string DebugString(const OutputTensor& tensor) { + return strings::StrCat(tensor.node->name(), ":", tensor.index); +} + +string DebugString(CondStateMap::CondId cond_state) { + if (cond_state == nullptr || cond_state->empty()) return "[]"; + return strings::StrCat( + "[", + absl::StrJoin(*cond_state, ", ", + [](string* output, const CondStateMap::CondNode& node) { + strings::StrAppend(output, node.ToString()); + }), + "]"); +} + +string Branch_Name(BranchType b) { + switch (b) { + case BranchType::kElseBranch: + return "else"; + case BranchType::kThenBranch: + return "then"; + case BranchType::kBoth: + return "both"; + case BranchType::kNeither: + return "neither"; + } +} + +// Returns the predicate of a switch. +Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { + const Edge* pred_edge; + TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge)); + // The predicate can be preceded by a identity node. 
Look through + // identity nodes to predicate. + while (pred_edge->src()->IsIdentity()) { + TF_RETURN_IF_ERROR(pred_edge->src()->input_edge(0, &pred_edge)); + } + *pred = OutputTensor(pred_edge->src(), pred_edge->src_output()); + return Status::OK(); +} + +CondStateMap::CondNode::CondNode(Type type, Node* switch_node, + BranchType branch) + : type(type), branch(branch) { + if (type == Type::kSwitch) { + TF_CHECK_OK(GetSwitchPredicate(*switch_node, &predicate)); + } +} + +string CondStateMap::CondNode::ToString() const { + switch (type) { + case Type::kSwitch: + return strings::StrCat("s(", DebugString(predicate), ",", + Branch_Name(branch), ")"); + case Type::kMerge: + return "m"; + case Type::kDead: + return "d"; + } +} + +bool CondStateMap::CondNode::operator==(const CondNode& other) const { + if (type != Type::kSwitch) return type == other.type; + return type == other.type && predicate == other.predicate && + branch == other.branch; +} + +bool CondStateMap::CondNode::operator!=(const CondNode& other) const { + return !(*this == other); +} + +CondStateMap::CondStateMap(Graph* graph) { + node_to_condid_map_.resize(graph->num_node_ids()); + // Initialize the dead state (empty state is designated with a nullptr). 
+ dead_id_ = GetUniqueId({CondNode(CondStateMap::CondNode::Type::kDead)}); +} + +bool CondStateMap::IsDead(CondStateMap::CondId id) const { + return id == dead_id_; +} + +bool CondStateMap::IsEmpty(CondStateMap::CondId id) const { + return id == nullptr; +} + +size_t CondStateMap::CondHash::operator()( + const CondStateMap::CondNode& item) const { + return Hash64Combine(Hash64Combine(OutputTensor::Hash()(item.predicate), + hash()(item.branch)), + hash()(item.type)); +} + +size_t CondStateMap::CondHash::operator()( + const CondStateMap::CondState& vec) const { + if (vec.empty()) return 0; + size_t h = (*this)(vec.front()); + auto it = vec.begin(); + for (++it; it != vec.end(); ++it) { + h = Hash64Combine(h, (*this)(*it)); + } + return h; +} + +// CondArgNode represents a input to the conditional and its corresponding +// switch nodes. +struct CondArgNode { + explicit CondArgNode(Node* src, int src_output) + : src(src), src_output(src_output) {} + + string ToString() const { + return strings::StrCat("src=", src->name(), ":", src_output, + " switches=", NodesToString(switches)); + } + + Node* src; + int src_output; + std::array branch_copy; + std::vector switches; +}; +using CondArgNodes = std::vector; + +string DebugString(const CondArgNodes& nodes) { + return strings::StrCat( + "[", + absl::StrJoin(nodes, ", ", + [](string* output, const CondArgNode& node) { + strings::StrAppend(output, node.ToString()); + }), + "]"); +} + +CondStateMap::CondId CondStateMap::LookupId(const Node* node) const { + if (node->id() < node_to_condid_map_.size()) + return node_to_condid_map_[node->id()]; + return added_node_mapping_.at(node->id()); +} + +CondStateMap::CondId CondStateMap::GetUniqueId( + const CondStateMap::CondState& state) { + if (state.empty()) return nullptr; + return &*condstate_set_.insert(state).first; +} + +const CondStateMap::CondState& CondStateMap::LookupState( + const Node* node) const { + return *LookupId(node); +} + +void CondStateMap::ResetId(const Node* node, 
CondStateMap::CondId id) { + if (node->id() < node_to_condid_map_.size()) + node_to_condid_map_[node->id()] = id; + else + added_node_mapping_[node->id()] = id; +} + +void CondStateMap::MarkDead(const Node* node) { ResetId(node, dead_id_); } + +string CondStateMap::CondStateToString(const Node* node) const { + return CondStateToString(LookupId(node)); +} + +string CondStateMap::CondStateToString(CondStateMap::CondId id) const { + return DebugString(id); +} + +FunctionalizeCond::FunctionalizeCond(Graph* graph, + FunctionLibraryDefinition* library) + : cond_state_map_(graph), library_(library), graph_(graph) {} + +// Class representing the merge/switch nodes that will become a conditional. +class Conditional { + public: + Conditional(OutputTensor predicate, FunctionalizeCond* parent, + CondStateMap* cond_state_map); + + // Adds merge node that is part of this conditional. + Status AddMerge(Node* m); + + // Constructs an If node from the merge nodes. + Status BuildAndReplace(Graph* graph, FunctionLibraryDefinition* library); + + private: + // Extracts the then/else bodies: creates new graphs with the nodes + // corresponding to the nodes in the then/else branches as of this conditional + // as function bodies. + Status ExtractBodies(Graph* graph); + + // Builds the arguments that are the input to the If. + Status BuildArgumentNodes(); + + // Builds the If node for the extracted bodies with the given predicate. + Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library); + + // Adds input edges to If node. + Status AddInputEdges(Graph* graph); + + // Adds output edges from If node. + Status AddOutputEdges(Graph* graph); + + // Adds switch node that is part of this conditional. + Status AddSwitch(Node* s); + + // Internal name of conditional. The name is based on the first merge node + // added. + string name() const; + + // The FunctionalizeCond instance that created this. + FunctionalizeCond* parent_; + + // Mapping between nodes and their cond state. 
+ CondStateMap* cond_state_map_; + + // The predicate of the conditional. + OutputTensor predicate_; + + // The predicate of the switches of the conditional. This may be different + // than predicate (which is initialized from the original graph) as the + // predicate could be the output of a newly created If node. + OutputTensor switch_predicate_; + + // Switch nodes in graph that are part of this conditional. + std::set switches_; + + // Merge nodes in graph that are part of this conditional. + std::set merges_; + + // Vector of control inputs from outside the conditional to a node inside. + std::vector external_control_inputs_; + std::vector external_control_outputs_; + + // Graphs corresponding to the then and else branch. + std::array, 2> bodies_; + + // Maps from graph_ to the branch body's graph. + std::array, 2> node_maps_; + + // The argument nodes created for the switches. + CondArgNodes cond_arg_nodes_; + + // The constructed If node. + Node* if_node_ = nullptr; + + // Whether the merge nodes of this conditional have been replaced. 
+ bool replaced_ = false; +}; + +Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent, + CondStateMap* cond_state_map) + : parent_(parent), cond_state_map_(cond_state_map), predicate_(predicate) {} + +Status Conditional::AddMerge(Node* m) { + merges_.insert(m); + return Status::OK(); +} + +Status Conditional::AddSwitch(Node* s) { + VLOG(5) << "Adding switch " << s->DebugString(); + OutputTensor predicate; + TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate)); + if (switch_predicate_.node == nullptr) switch_predicate_ = predicate; + if (!(switch_predicate_ == predicate)) { + return errors::InvalidArgument( + "Merge nodes ", NodesToString(merges_), + " directly dominated by switch nodes with different predicates (", + DebugString(switch_predicate_), " vs ", DebugString(predicate), ")."); + } + switches_.insert(s); + return Status::OK(); +} + +Status Conditional::BuildArgumentNodes() { + VLOG(1) << "Build function arguments"; + struct Hash { + size_t operator()(const std::pair& item) const { + return Hash64Combine(hash()(item.first), + std::hash()(item.second)); + } + }; + + std::unordered_map, int, Hash> input_index; + for (Node* switch_node : switches_) { + const Edge* e; + TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e)); + std::pair key = std::make_pair(e->src(), e->src_output()); + if (input_index.find(key) == input_index.end()) { + input_index[key] = cond_arg_nodes_.size(); + cond_arg_nodes_.emplace_back(key.first, key.second); + } + cond_arg_nodes_.at(input_index.at(key)).switches.push_back(switch_node); + } + VLOG(5) << "CondArg nodes created: " << DebugString(cond_arg_nodes_); + + int arg_count = 0; + for (CondArgNode& cond_arg_node : cond_arg_nodes_) { + DataType dtype = cond_arg_node.src->output_type(cond_arg_node.src_output); + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("_Arg", arg_count), + 
FunctionLibraryDefinition::kArgOp) + .Attr("T", dtype) + .Attr("index", arg_count) + .Finalize(bodies_[branch_index].get(), + &cond_arg_node.branch_copy[branch_index])); + } + for (Node* node : cond_arg_node.switches) { + for (const Edge* e : node->out_edges()) { + if (e->IsControlEdge()) continue; + int branch_index = e->src_output(); + Node* src_copy = cond_arg_node.branch_copy[branch_index]; + Node* dst_copy = node_maps_[branch_index][e->dst()->id()]; + + // The graph may contain dead switch nodes, + if (dst_copy == nullptr) continue; + + TF_RET_CHECK(dst_copy != nullptr) + << "Unable to find copied node for " << e->dst()->DebugString() + << " on branch " << Branch_Name(BranchType(branch_index)); + // If the input goes directly to a merge then the merge has + // been replaced by a retval so the dst input is 0 instead of + // dst_input. + int dst_input = IsMerge(e->dst()) ? 0 : e->dst_input(); + bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input); + } + } + ++arg_count; + } + + // Verify that all retvals have an input. + // TODO(jpienaar): One could add a ZerosLike in the branch that doesn't have + // input. 
+ for (Node* m : merges_) { + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + bool has_input = false; + for (auto e : node_maps_[static_cast(branch)][m->id()]->in_edges()) { + if (!e->IsControlEdge()) { + has_input = true; + break; + } + } + if (!has_input) { + return errors::Internal( + "Failed to functionalize control flow with merge ", + FormatNodeForError(*m), " that doesn't have input on ", + Branch_Name(branch), " branch."); + } + } + } + + return Status::OK(); +} + +Status Conditional::ExtractBodies(Graph* graph) { + VLOG(2) << "Extracting bodies for " << name(); + for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) { + bodies_[static_cast(b)] = + absl::make_unique(graph->op_registry()); + } + + auto find_branch = [&](const Edge* e) { + const auto& id = cond_state_map_->LookupId(e->src()); + return IsSwitch(e->src()) ? BranchType(e->src_output()) + : cond_state_map_->FindBranchOf(id, predicate_); + }; + + std::array, 2> stacks; + VLOG(5) << "Merges: " << NodesToString(merges_); + for (Node* m : merges_) { + VLOG(5) << "For merge: " << m->DebugString() << " " + << cond_state_map_->CondStateToString(m); + for (auto e : m->in_edges()) { + if (e->IsControlEdge()) continue; + BranchType branch = find_branch(e); + TF_RET_CHECK(branch == BranchType::kThenBranch || + branch == BranchType::kElseBranch) + << "Error: " << e->src()->name() + << " is not on either then or else branch (" << Branch_Name(branch) + << ")."; + Node* src = e->src(); + if (IsSwitch(src)) { + // Switch node outputs and dependencies are handled separately. 
+ TF_RETURN_IF_ERROR(AddSwitch(src)); + } else { + stacks[static_cast(branch)].push_back(src); + } + } + } + + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + auto output = bodies_[branch_index].get(); + auto& stack = stacks[branch_index]; + VLOG(5) << "In branch: " << Branch_Name(branch) << " " + << NodesToString(stack); + std::vector visited(graph->num_node_ids(), false); + node_maps_[branch_index].resize(graph->num_node_ids(), nullptr); + auto& node_map = node_maps_[branch_index]; + + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); + + if (visited.at(n->id())) continue; + visited[n->id()] = true; + + // Verify output edges and record control edges exitting scope. + for (const Edge* e : n->out_edges()) { + Node* dst = e->dst(); + if (IsMerge(dst)) continue; + Node* src = e->src(); + + auto dst_id = cond_state_map_->LookupId(dst); + auto src_id = cond_state_map_->LookupId(src); + if (dst_id != src_id) { + if (e->IsControlEdge()) { + external_control_outputs_.push_back(e->src()); + } else { + // Constants are treated specially to workaround the case of + // non-dominated constant nodes. + if (!IsConstant(src)) { + // TODO(b/78882471): A node that feeds into two different + // CondState is not necessarily an error so log a warning for now + // but revisit to improve the testing to enable making this an + // error. + LOG(WARNING) << errors::InvalidArgument( + "Graph contains node ", FormatNodeForError(*src), + " that feeds into node ", FormatNodeForError(*dst), + " but these nodes are in different control contexts (", + DebugString(src_id), " vs ", DebugString(dst_id), + " (detected during out edge testing)"); + } + } + } + } + + // Copying incomming edges to dst node. + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + // Skip src/dst node. 
+ if (!src->IsOp()) continue; + + Node* dst = e->dst(); + if (IsSwitch(src)) { + // Switch node outputs and dependencies are handled separately. + TF_RETURN_IF_ERROR(AddSwitch(src)); + continue; + } + + // Verify input is from the same context. + auto src_id = cond_state_map_->LookupId(src); + auto dst_id = cond_state_map_->LookupId(dst); + if (IsMerge(dst) || src_id == dst_id) { + // TODO(jpienaar): The merge case can be more strict. + if (node_map.at(src->id()) == nullptr) { + node_map.at(src->id()) = output->CopyNode(src); + stack.push_back(src); + } + } else if (e->IsControlEdge()) { + external_control_inputs_.push_back(src); + } else { + // This shouldn't happen, this means we have an external data input + // not entering via a switch node. Work around this for constant + // nodes as some constant nodes are inserted without the required + // control context dominance. + if (IsConstant(src)) { + node_map.at(src->id()) = output->CopyNode(src); + } else { + return errors::InvalidArgument( + "Graph contains node ", FormatNodeForError(*src), + " that feeds into node ", FormatNodeForError(*dst), + " but these nodes are in different control contexts (", + DebugString(src_id), " vs ", DebugString(dst_id), + " (detected during in edge testing)"); + } + } + + Node* src_copy = node_map.at(e->src()->id()); + int src_output = e->src_output(); + if (node_map.at(dst->id()) == nullptr) { + node_map.at(dst->id()) = output->CopyNode(dst); + } + Node* dst_copy = node_map.at(e->dst()->id()); + if (e->IsControlEdge()) { + // Skip control inputs from external context. + if (src_copy != nullptr) output->AddControlEdge(src_copy, dst_copy); + } else { + output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); + } + } + } + } + + // Build return values from the merge nodes. 
+ int index = 0; + for (Node* m : merges_) { + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + auto& node_map = node_maps_[branch_index]; + auto output = bodies_[branch_index].get(); + TF_ASSIGN_OR_RETURN(node_map[m->id()], + BuildRetvalNode(output, m->output_type(0), index)); + } + ++index; + + // Connect the input to the merge_ with the retval, except if it is a + // Swich node, which is handled separately. + for (auto e : m->in_edges()) { + if (e->IsControlEdge()) continue; + int branch_index = static_cast(find_branch(e)); + auto& node_map = node_maps_[branch_index]; + auto output = bodies_[branch_index].get(); + Node* in = e->src(); + if (!IsSwitch(in)) { + if (node_map.at(in->id()) == nullptr) { + node_map[in->id()] = output->CopyNode(in); + } + output->AddEdge(node_map[in->id()], e->src_output(), + node_map.at(m->id()), 0); + } + } + } + return Status::OK(); +} + +Status Conditional::BuildIfNode(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(2) << "Build cond function for " << name(); + NodeDefBuilder builder(name(), "If"); + const string branch_name[] = {"else_branch", "then_branch"}; + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + static std::atomic sequence_num(0LL); + int64 id = ++sequence_num; + + NameAttrList body_name; + body_name.set_name(strings::StrCat("_functionalize_if_", + branch_name[branch_index], "_", id)); + + VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index] + << "): " + << dump_graph::DumpGraphToFile( + "functionalize_cond_body_" + branch_name[branch_index], + *bodies_[branch_index], nullptr); + + FunctionDef body_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*bodies_[branch_index], + body_name.name(), &body_fdef)); + TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + builder.Attr(branch_name[branch_index], body_name); + } + + VLOG(3) << "Build input type"; + 
std::vector inputs; + DataTypeVector in_arg_types; + for (auto& kv : cond_arg_nodes_) { + bool inserted = false; + for (const Node* arg : kv.switches) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + if (!inserted) { + DataType dtype = arg->input_type(0); + inputs.emplace_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), dtype)); + in_arg_types.push_back(dtype); + inserted = true; + } + } + } + } + builder.Attr("Tin", in_arg_types); + + DataTypeVector out_type; + for (const Node* merge : merges_) { + DataType dtype = merge->output_type(0); + out_type.push_back(dtype); + } + builder.Attr("Tout", out_type); + VLOG(3) << "Build output type: " << DataTypeVectorString(out_type); + + builder.Attr("Tcond", DT_BOOL); + builder.Device(predicate_.node->assigned_device_name()); + // Conditional should be the first input ... + builder.Input(NodeDefBuilder::NodeOut(predicate_.node->name(), + predicate_.index, + predicate_.node->output_type(0))); + // ... followed by the other inputs. + builder.Input(inputs); + + VLOG(3) << "Build If node"; + NodeDef if_def; + TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); + TF_ASSIGN_OR_RETURN(if_node_, parent_->AddIfNode(if_def, *merges_.begin())); + + return Status::OK(); +} + +Status Conditional::AddInputEdges(Graph* graph) { + VLOG(2) << "AddInputEdges for " << if_node_->name(); + int index = 0; + // Add predicate input. + graph->AddEdge(const_cast(predicate_.node), predicate_.index, if_node_, + index++); + // Add function body inputs. 
+ for (auto& arg : cond_arg_nodes_) { + if (arg.src_output == Graph::kControlSlot) { + graph->AddControlEdge(arg.src, if_node_); + } else { + graph->AddEdge(arg.src, arg.src_output, if_node_, index++); + } + } + for (Node* n : external_control_inputs_) { + graph->AddControlEdge(n, if_node_); + } + return Status::OK(); +} + +Status Conditional::AddOutputEdges(Graph* graph) { + VLOG(2) << "AddOutputEdges for " << if_node_->name(); + int i = 0; + for (Node* node : merges_) { + TF_RETURN_IF_ERROR(parent_->AddIdentityNode(node, if_node_, i)); + std::vector edges(node->out_edges().begin(), + node->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + if (edge->src_output() > 0) { + return errors::Unimplemented("Output of index (", edge->src_output(), + ") of merge node ", + FormatNodeForError(*node)); + } + + bool control_edge = edge->IsControlEdge(); + graph->RemoveEdge(edge); + if (control_edge) { + graph->AddControlEdge(if_node_, dst); + } else { + graph->AddEdge(if_node_, i, dst, dst_input); + } + } + ++i; + } + for (Node* n : external_control_outputs_) { + graph->AddControlEdge(if_node_, n); + } + + return Status::OK(); +} + +Status Conditional::BuildAndReplace(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(1) << "Build If and replace merge nodes " << name(); + if (replaced_) return Status::OK(); + + TF_RETURN_IF_ERROR(ExtractBodies(graph)); + TF_RETURN_IF_ERROR(BuildArgumentNodes()); + + if (VLOG_IS_ON(3)) { + LOG(INFO) << "Extracted bodies:"; + for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { + int branch_index = static_cast(branch); + auto output = bodies_[branch_index].get(); + LOG(INFO) << Branch_Name(branch) << ": " + << DebugString(output->ToGraphDefDebug()); + } + } + + TF_RETURN_IF_ERROR(BuildIfNode(graph, library)); + TF_RETURN_IF_ERROR(AddInputEdges(graph)); + TF_RETURN_IF_ERROR(AddOutputEdges(graph)); + 
TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_)); + for (Node* m : merges_) cond_state_map_->MarkDead(m); + + // Check that the if_node doesn't feed into itself. + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNodeNotInCycle(if_node_, graph->num_node_ids()), + "Converting to If failed."); + + replaced_ = true; + return Status::OK(); +} + +string Conditional::name() const { + CHECK(!merges_.empty()); + return strings::StrCat((*merges_.begin())->name(), "_if"); +} + +bool CondStateMap::ScopeIn(CondStateMap::CondId id, + CondStateMap::CondId* scope) { + if (id == nullptr) { + *scope = nullptr; + return true; + } + CondState state; + for (const CondNode& node : *id) { + if (node.type == CondNode::Type::kSwitch) { + state.push_back(node); + } + if (node.type == CondNode::Type::kMerge) { + if (state.empty()) { + return false; + } + DCHECK(state.back().type == CondNode::Type::kSwitch && + state.back().branch == BranchType::kBoth); + state.pop_back(); + } + } + *scope = GetUniqueId(state); + return true; +} + +Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, + int port) { + Node* id; + TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity") + .Input(if_node, port) + .Finalize(graph_, &id)); + cond_state_map_.ResetId(id, cond_state_map_.LookupId(if_node)); + return Status::OK(); +} + +StatusOr FunctionalizeCond::AddIfNode(const NodeDef& def, + const Node* replacee) { + Status status; + Node* ret = graph_->AddNode(def, &status); + TF_RETURN_IF_ERROR(status); + CondStateMap::CondState state = cond_state_map_.LookupState(replacee); + state.pop_back(); + VLOG(1) << "Adding If for " << replacee->name(); + cond_state_map_.ResetId(ret, cond_state_map_.GetUniqueId(state)); + return ret; +} + +Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { + VLOG(2) << "Propagating update state for " << replacee->name() << " " + << cond_state_map_.CondStateToString(replacee); + // Redo topological sort as the order could have changed. 
+ // TODO(jpienaar): The original topological order could also be updated + // dynamically if needed. + std::vector rev_topo_order; + GetPostOrder(*graph_, &rev_topo_order); + + // All the outputs of the new node could potentially be updated. + std::unordered_set changed; + for (auto n : replacee->out_nodes()) + if (n->IsOp()) changed.insert(n); + + // Iterate through the changed/possible changed nodes in topological order. + for (auto it = rev_topo_order.rbegin(); + it != rev_topo_order.rend() && !changed.empty(); ++it) { + if (changed.find(*it) != changed.end()) { + // Update the node state. + Node* n = *it; + CondStateMap::CondId old_state = cond_state_map_.LookupId(n); + cond_state_map_.ResetId(n, nullptr); + TF_RETURN_IF_ERROR(DetermineCondState(n)); + if (cond_state_map_.LookupId(n) != old_state) { + for (auto out : n->out_nodes()) + if (out->IsOp()) changed.insert(out); + } + changed.erase(n); + } + } + return Status::OK(); +} + +// Returns the most restrictive branch of two branches or neither. This is the +// meet operator of the BranchType lattice. +BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) { + if (lhs == rhs) return lhs; + if (lhs == BranchType::kNeither) return rhs; + if (rhs == BranchType::kNeither) return lhs; + if (lhs == BranchType::kBoth) return rhs; + if (rhs == BranchType::kBoth) return lhs; + return BranchType::kNeither; +} + +CondStateMap::ContainsResult CondStateMap::LhsHoldsWhereverRhsHolds( + CondStateMap::CondId lhs, CondStateMap::CondId rhs) { + CondId lhs_scope; + CondId rhs_scope; + bool could_determine_scope = ScopeIn(lhs, &lhs_scope); + could_determine_scope = could_determine_scope && ScopeIn(rhs, &rhs_scope); + if (!could_determine_scope) return kIncomparable; + + // Returns whether a contains b. + auto contains = [&](CondId a, CondId b) { + // Handle empty states. 
+ if (a == nullptr && b != nullptr) return true; + if (a == nullptr && b == nullptr) return true; + if (a != nullptr && b == nullptr) return false; + + if (a->size() > b->size()) return false; + auto a_it = a->begin(); + auto b_it = b->begin(); + while (a_it != a->end()) { + if (*a_it != *b_it) { + if (!(a_it->predicate == b_it->predicate)) return false; + BranchType mb = MeetBranch(a_it->branch, b_it->branch); + if (mb != b_it->branch) return false; + } + ++a_it; + ++b_it; + } + return true; + }; + + bool lhs_contains_rhs = contains(lhs_scope, rhs_scope); + bool rhs_contains_lhs = contains(rhs_scope, lhs_scope); + if (lhs_contains_rhs && rhs_contains_lhs) return kEqual; + if (lhs_contains_rhs) return kLhsContainsRhs; + if (rhs_contains_lhs) return kRhsContainsLhs; + return kIncomparable; +} + +BranchType CondStateMap::FindBranchOf(CondId id, OutputTensor predicate) const { + if (IsEmpty(id)) return BranchType::kNeither; + absl::optional b; + const CondState& nodes = *id; + for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) { + if (it->type == CondStateMap::CondNode::Type::kSwitch && + it->predicate == predicate) { + if (b.has_value()) { + b = MeetBranch(*b, it->branch); + } else { + b = it->branch; + } + if (*b == BranchType::kNeither) { + LOG(FATAL) << "Inconsistent state for node: " << DebugString(id); + } + } + } + return b.has_value() ? *b : BranchType::kNeither; +} + +StatusOr FunctionalizeCond::JoinCondStatesNonMerge( + CondStateMap::CondId src, CondStateMap::CondId dst) { + VLOG(4) << "Joining src=" << DebugString(src) << " [" << src + << "] and dst=" << DebugString(dst) << " [" << dst << "]"; + + if (cond_state_map_.IsEmpty(dst) || cond_state_map_.IsDead(src)) return src; + if (cond_state_map_.IsDead(dst)) return dst; + + // Nothing to do if the CondState is the same. 
+ if (src == dst) return src; + + CondStateMap::CondId src_scope; + CondStateMap::CondId dst_scope; + if (!cond_state_map_.ScopeIn(src, &src_scope)) + return errors::Unimplemented( + "Predicates that must hold for node to execute are invalid! ", + DebugString(src)); + if (!cond_state_map_.ScopeIn(dst, &dst_scope)) + return errors::Unimplemented( + "Predicates that must hold for node to execute are invalid! ", + DebugString(dst)); + + auto result = cond_state_map_.LhsHoldsWhereverRhsHolds(src_scope, dst_scope); + switch (result) { + case CondStateMap::kIncomparable: + return errors::InvalidArgument( + "Graph contains node with inputs predicated on incompatible " + "predicates: ", + DebugString(src), " and ", DebugString(dst)); + case CondStateMap::kEqual: + // If both respect the same predicates, propagate the longer constraint. + if ((src != nullptr && dst == nullptr) || + (src != nullptr && dst != nullptr && src->size() > dst->size())) + return src; + else + return dst; + case CondStateMap::kLhsContainsRhs: + // src contains dst, so dst is already more restrictive. + return dst; + case CondStateMap::kRhsContainsLhs: + // dst contains src, so src is more restrictive. + return src; + } +} + +StatusOr +FindThenElseSwitchForPredicate(const OutputTensor& pred, + CondStateMap::CondId id) { + for (auto it = id->begin(); it != id->end(); ++it) { + // Along every path one there can be only one instance of a then or else + // switch for a given predicate, so return once found. + if (it->type == CondStateMap::CondNode::Type::kSwitch && + it->predicate == pred && + (it->branch == BranchType::kThenBranch || + it->branch == BranchType::kElseBranch)) + return it; + } + return errors::Internal("Unable to find then/else branch with predicate ", + DebugString(pred), " for ", DebugString(id)); +} + +StatusOr FunctionalizeCond::JoinCondStatesMerge( + CondStateMap::CondId src, CondStateMap::CondId dst) { + // Determine the flow state when joining two states for a merge + // node. 
Combining the two states for a merge node is effectively performing a + // disjunction of the states along the different input edges. For a merge that + // can be transformed into an If, the two input paths have to have a predicate + // on which they differ (e.g., along one edge predicate `p` has to hold while + // on another it should not). This function first determines this predicate + // and then the resultant state is the common path between the two inputs + // followed by s(p, both). + VLOG(4) << "Joining (for merge) " << DebugString(src) << " and " + << DebugString(dst); + if (cond_state_map_.IsEmpty(dst)) return src; + + if (cond_state_map_.IsDead(src)) return src; + if (cond_state_map_.IsDead(dst)) return dst; + + CondStateMap::CondId src_scope; + CondStateMap::CondId dst_scope; + if (!cond_state_map_.ScopeIn(src, &src_scope)) + return errors::Unimplemented( + "Predicates that must hold for node to execute are invalid! ", + DebugString(src)); + if (!cond_state_map_.ScopeIn(dst, &dst_scope)) + return errors::Unimplemented( + "Predicates that must hold for node to execute are invalid! ", + DebugString(dst)); + + TF_RET_CHECK(src_scope != nullptr && dst_scope != nullptr) + << "Illegal merge inputs from outer scope: src=" << DebugString(src) + << " dst=" << DebugString(dst); + auto src_it = src_scope->begin(); + auto dst_it = dst_scope->begin(); + + // Find branch divergent condition. 
+ OutputTensor pred; + while (src_it != src_scope->end() && dst_it != dst_scope->end()) { + if (*src_it != *dst_it) { + VLOG(5) << "Diverges with: " << DebugString(*src_it) << " and " + << DebugString(*dst_it); + if (!(src_it->predicate == dst_it->predicate)) { + return errors::InvalidArgument( + "Unable to find common predicate which holds for one input " + "but not the other of the merge node."); + } + pred = src_it->predicate; + break; + } + ++src_it; + ++dst_it; + } + + if (pred.node == nullptr) + return errors::InvalidArgument("Unable to determine predicate for merge."); + + TF_ASSIGN_OR_RETURN(auto div_src_it, + FindThenElseSwitchForPredicate(pred, src)); + TF_ASSIGN_OR_RETURN(auto div_dst_it, + FindThenElseSwitchForPredicate(pred, dst)); + TF_RET_CHECK(*div_src_it != *div_dst_it); + + CondStateMap::CondState result; + // Populate result with the longest/most restrictive path up to the divergent + // node. For example, if the one input is `[switch(pred:0, then)]` and the + // other is `[switch(pred:0, both), merge, switch(pred:0, else)]` (as created + // in gradient of cond test), then the resultant state here should be + // `[switch(pred:0, both), merge, switch(pred:0, both)]`. 
+ if (std::distance(src->begin(), div_src_it) > + std::distance(dst->begin(), div_dst_it)) { + result.assign(src->begin(), std::next(div_src_it)); + } else { + result.assign(dst->begin(), std::next(div_dst_it)); + } + result.back().branch = BranchType::kBoth; + return cond_state_map_.GetUniqueId(result); +} + +CondStateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { + Node* src = e->src(); + CondStateMap::CondId id = cond_state_map_.LookupId(e->src()); + if (IsMerge(src)) { + CondStateMap::CondState state; + if (id != nullptr) state = *id; + state.emplace_back(CondStateMap::CondNode::Type::kMerge); + return cond_state_map_.GetUniqueId(state); + } + if (IsSwitch(src)) { + CondStateMap::CondState state; + if (id != nullptr) state = *id; + if (e->IsControlEdge()) { + state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src, + BranchType::kBoth); + } else { + state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src, + BranchType(e->src_output())); + } + return cond_state_map_.GetUniqueId(state); + } + return id; +} + +Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { + // Only Merge nodes with two inputs are supported, but if this is a redundant + // merge, then the dead edge may already have been removed (if due to a + // switch) and so the input count would be incorrect. 
+ if (cond_state_map_.IsDead(cond_state_map_.LookupId(dst))) + return Status::OK(); + + int data_inputs = 0; + for (auto e : dst->in_edges()) { + Node* src = e->src(); + VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " " + << cond_state_map_.CondStateToString(src); + if (!src->IsOp()) continue; + if (!e->IsControlEdge()) ++data_inputs; + + CondStateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst)); + cond_state_map_.ResetId(dst, id_or.ValueOrDie()); + } + + // Incomplete Merge nodes are not supported. + if (data_inputs != 2) { + return errors::Unimplemented( + dst->name(), " only has ", data_inputs, + " inputs, while only merge nodes with two inputs supported."); + } + return Status::OK(); +} + +Status FunctionalizeCond::DetermineCondState(Node* dst) { + // The logic for the merge and non-merge case differ: for non-merge it is + // the most restrictive CondState, while for merge nodes the + // resultant state is less restrictive than either. + if (IsMerge(dst)) { + TF_RETURN_IF_ERROR(DetermineCondStateMerge(dst)); + } else { + // Handle non-merge join. + for (auto e : dst->in_edges()) { + VLOG(5) << "Processing forward flow for: " << e->DebugString() << " " + << cond_state_map_.CondStateToString(dst); + Node* src = e->src(); + if (!src->IsOp()) continue; + + // Joining the state between the current and propagated state. + CondStateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst)); + cond_state_map_.ResetId(dst, id_or.ValueOrDie()); + } + } + return Status::OK(); +} + +Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { + // Handle redundant merge nodes. 
A merge node is considered redundant if + // one input edge is dead while the other has a value. + if (!cond_state_map_.IsDead(cond_state_map_.LookupId(node))) + return Status::OK(); + + const Edge* non_dead_edge = nullptr; + for (auto e : node->in_edges()) { + if (e->IsControlEdge()) continue; + Node* src = e->src(); + + // Handle merge with dead state. + const auto& src_id = cond_state_map_.LookupId(src); + if (!cond_state_map_.IsDead(src_id)) { + non_dead_edge = e; + break; + } + } + + if (non_dead_edge == nullptr) { + return errors::InvalidArgument("Merge node ", FormatNodeForError(*node), + " has no non-dead inputs."); + } + cond_state_map_.MarkDead(node); + delete_nodes_.push_back(node->id()); + VLOG(5) << "removing redundant merge: " << node->name(); + while (!node->out_edges().empty()) { + const Edge* oe = *node->out_edges().begin(); + Node* dst_node = oe->dst(); + int dst_port = oe->dst_input(); + graph_->RemoveEdge(oe); + graph_->AddEdge(non_dead_edge->src(), + dst_port == Graph::kControlSlot + ? Graph::kControlSlot + : non_dead_edge->src_output(), + dst_node, dst_port); + } + return Status::OK(); +} + +Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { + // Handle redundant switch nodes. A switch node is considered redundant if + // the predicate of the switch already holds on the current branch. E.g., if + // p is the predicate of the switch but p is already known to hold on this + // branch, then the switch can be removed and the dead state propagated + // along one. The checking of predicate is based on the exact predicate + // (rather than boolean equivalence) and aimed at redundant switches as + // currently generated by gradient code. + OutputTensor pred; + TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred)); + auto dst_id = cond_state_map_.LookupId(node); + BranchType b = cond_state_map_.FindBranchOf(dst_id, pred); + // Determine if we are already on a branch where the switch predicate is + // true/false. 
+ if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) + return Status::OK(); + + VLOG(5) << "Redundant switch " << node->name(); + const Edge* value_edge; + TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge)); + Node* val_node = value_edge->src(); + int val_port = value_edge->src_output(); + while (!node->out_edges().empty()) { + auto e = *node->out_edges().begin(); + Node* dst_node = e->dst(); + int dst_input = e->dst_input(); + int switch_branch = e->src_output(); + graph_->RemoveEdge(e); + if (switch_branch == Graph::kControlSlot) { + if (IsMerge(dst_node)) { + auto id_or = + JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst_node)); + cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); + } else { + auto id_or = + JoinCondStatesNonMerge(dst_id, cond_state_map_.LookupId(dst_node)); + TF_RETURN_IF_ERROR(id_or.status()); + cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); + } + } else if (BranchType(switch_branch) != b) { + cond_state_map_.MarkDead(dst_node); + delete_nodes_.push_back(dst_node->id()); + continue; + } + graph_->AddEdge( + val_node, + switch_branch == Graph::kControlSlot ? Graph::kControlSlot : val_port, + dst_node, dst_input); + } + return Status::OK(); +} + +Status FunctionalizeCond::DetermineCondStates( + std::vector rev_topo_order) { + // The state that is propagated along the given edge. + for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) { + Node* dst = *it; + TF_RETURN_IF_ERROR(DetermineCondState(dst)); + if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst)); + if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst)); + + VLOG(5) << dst->name() << " :: " << cond_state_map_.CondStateToString(dst); + } + return Status::OK(); +} + +void FunctionalizeCond::DeleteReachableNodes() { + // Delete all nodes that have been extracted or are reachable from + // deleted/dead nodes. 
The input and outgoing edges should have already been + // removed. + std::vector deleted(graph_->num_node_ids(), false); + // Don't try to delete source or sink nodes. + deleted[graph_->kSourceId] = true; + deleted[graph_->kSinkId] = true; + while (!delete_nodes_.empty()) { + int d_id = delete_nodes_.front(); + delete_nodes_.pop_front(); + if (deleted[d_id]) continue; + Node* d = graph_->FindNodeId(d_id); + // Switch and Merge nodes could have been deleted already. + if (d == nullptr) continue; + for (const Edge* e : d->out_edges()) { + delete_nodes_.push_back(e->dst()->id()); + } + deleted[d_id] = true; + graph_->RemoveNode(d); + } +} + +void FunctionalizeCond::SortMergeNodes(std::vector* merge_order) { + // Sort merge nodes by nesting depth. + using sort_pair = std::pair; + std::vector inner_to_outer_merge_order; + inner_to_outer_merge_order.reserve(merge_order->size()); + for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) { + Node* merge = *it; + CondStateMap::CondId id = cond_state_map_.LookupId(merge); + int depth = 0; + for (auto cond_node_it = id->begin(); cond_node_it != id->end(); + ++cond_node_it) { + if (cond_node_it->type == CondStateMap::CondNode::Type::kSwitch && + (cond_node_it->branch == BranchType::kThenBranch || + cond_node_it->branch == BranchType::kElseBranch)) { + ++depth; + } + } + inner_to_outer_merge_order.emplace_back(depth, merge); + } + std::stable_sort( + inner_to_outer_merge_order.begin(), inner_to_outer_merge_order.end(), + [](sort_pair lhs, sort_pair rhs) { return lhs.first > rhs.first; }); + merge_order->clear(); + for (sort_pair t : inner_to_outer_merge_order) { + merge_order->push_back(t.second); + } +} + +Status FunctionalizeCond::FunctionalizeInternal() { + // The general approach for converting a tf.cond (as lowered via switch/merge + // nodes) to a functional if is as follows: + // 1. Determine the topological order and collect all the switch and merge + // nodes in the graph; + // 2. 
Compute the predicates and dominance structure for all the nodes in the + // graph - this includes which predicate must be true for a op to execute + // (predicate values are considered directly rather than attempting to + // determine deeper equivalence). We shall refer to this structure as the + // CondState; + // 3. Sort the merge nodes by nesting depth; + // 4. Extract merge nodes together that have the same CondState and whose + // input nodes have the same state from the innermost to the outermost into + // IfOps; Note: In the above only nodes paths that converge to a merge node + // will be considered for removal. + + // Perform a DFS over the graph and + // * Determine the reverse topological order of the nodes (there should be no + // cycles at this point so the post-order numbering corresponds to the + // reverse topological sorting); + // * Record reverse topological for merge and switch nodes; + std::vector rev_topo_order; + std::vector switch_ids; + std::vector merge_order; + DFS(*graph_, nullptr, [&](Node* n) { + if (IsSwitch(n)) { + switch_ids.push_back(n->id()); + } + if (IsMerge(n)) { + merge_order.push_back(n); + } + if (n->IsOp()) { + rev_topo_order.push_back(n); + } + }); + + // No merges to functionalize. + if (merge_order.empty()) { + // No merges mean no switch values consumed (as only considering values + // fetchable as output of merge); + for (auto it = switch_ids.begin(); it != switch_ids.end(); ++it) { + graph_->RemoveNode(graph_->FindNodeId(*it)); + } + return Status::OK(); + } + + TF_RETURN_IF_ERROR(DetermineCondStates(std::move(rev_topo_order))); + + if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id"); + + // Sort the merge nodes from innermost outwards. + SortMergeNodes(&merge_order); + + // Extract from innermost out. 
+ for (auto it = merge_order.begin(); it != merge_order.end(); ++it) { + Node* merge = *it; + auto id = cond_state_map_.LookupId(merge); + if (cond_state_map_.IsDead(id)) continue; + + // Construct a Conditional with the predicate of the merge (which is the + // last entry of the CondState for the merge) and this as parent. + DCHECK(id->back().predicate.node != nullptr); + Conditional cond(id->back().predicate, this, &cond_state_map_); + TF_RETURN_IF_ERROR(cond.AddMerge(merge)); + + // Find all merge nodes with the same CondId. This is done repeatedly as + // the CondId can change due replaced conditionals. E.g., the one branch + // could previously have had a conditional nested in it, and so would have + // had CondState with sub-state [switch(p,b),m] (where p is some predicate), + // post removing the nested conditional that sub-state would no longer be + // path of the propagated state along that path. + auto end = merge_order.end(); + for (auto merge_candidate_it = std::next(it); merge_candidate_it != end; + ++merge_candidate_it) { + auto merge_candidate_it_id = + cond_state_map_.LookupId(*merge_candidate_it); + if (merge_candidate_it_id != id) continue; + TF_RETURN_IF_ERROR(cond.AddMerge(*merge_candidate_it)); + } + + TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_)); + + if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); + } + + // All remaining Switch nodes are not reachable from a Merge node and + // removed. This is to account for dead Switch nodes. 
+ for (int s_id : switch_ids) delete_nodes_.push_back(s_id); + for (Node* m : merge_order) delete_nodes_.push_back(m->id()); + DeleteReachableNodes(); + + return Status::OK(); +} + +void FunctionalizeCond::DumpGraphWithCondState(const string& name) { + const char* const kCondGroupDebugAttr = "_XlaFunctionalizeCondGroup"; + + for (Node* n : graph_->nodes()) { + n->ClearAttr(kCondGroupDebugAttr); + n->AddAttr(kCondGroupDebugAttr, cond_state_map_.CondStateToString(n)); + } + LOG(INFO) << "FunctionalizeControlFlow (" << name << "): " + << dump_graph::DumpGraphToFile( + strings::StrCat("functionalize_", name), *graph_, library_); +} + +Status FunctionalizeCond::Functionalize(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(1) << "FunctionalizeCond::Functionalize"; + FunctionalizeCond fc(graph, library); + return fc.FunctionalizeInternal(); +} + +} // namespace functionalize_cond + +Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) { + // FunctionalizeControlFlow is invoked for every function, so the loops's + // bodies and conditionals that were extracted into functions will be handled + // in successive invocations. + return functionalize_cond::FunctionalizeCond::Functionalize(graph, library); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h new file mode 100644 index 0000000000000000000000000000000000000000..86436011c6ebdc608a5811a1b0d6a10015d405bd --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -0,0 +1,248 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ + +#include +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Functionalize all the switch-merge nodes of a loop-free graph into If +// nodes. That is, attempt to transform every remaining switch and merge nodes +// in the graph into If nodes. +// Precondition: All while loops have been removed from graph. +Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); + +// Internal functions/classes exposed for testing purposes. +namespace functionalize_cond { + +// All nodes are assumed to be either in no branch, then branch, else branch, +// or both branches (such as merge nodes). +// The code below relies on Else and Then being 0 and 1 (corresponding to the +// switch outputs). Both and Neither are arbitrary. +enum class BranchType { + kElseBranch = 0, + kThenBranch = 1, + kBoth = 2, + kNeither = 3, +}; + +// CondStateMap is responsible for mapping from each graph Node to a CondState, +// where each CondState is the array of CondNodes (corresponding to switch, +// merge or dead states) as described below. For efficiency, this class interns +// the CondState, so that CondState equality comparisons are simply pointer +// comparisons. 
+class CondStateMap { + public: + explicit CondStateMap(Graph* graph); + + // Represents an entry in the CondState. An entry can either be the + // switch (along with predicate), merge, or dead: + // * switch node indicates a node that is executed along a branch with the + // given predicate - a branch can be then, else or both; + // * merge node indicates that the node is executed as output of a merge; + // * dead indicates that this node can never be executed; + struct CondNode { + enum class Type { kSwitch = 1, kMerge = 2, kDead = 3 }; + + CondNode(Type type, Node* switch_node = nullptr, + BranchType branch = BranchType::kNeither); + + string ToString() const; + bool operator==(const CondNode& other) const; + bool operator!=(const CondNode& other) const; + + // Type of node. + Type type; + + // Predicate and branch, only used when type is kSwitch. + OutputTensor predicate; + BranchType branch; + }; + + // A node in the graph is executed when multiple conditions hold. The order + // represents the nesting of the predicates that hold and is used when + // extracting the nested conditionals. + using CondState = std::vector; + + // Every unique ID is mapped to a CondState. + using CondId = const CondState*; + + // Returns the CondId for a given node. + CondId LookupId(const Node* node) const; + + // Returns the unique CondId for CondState. + CondId GetUniqueId(const CondState& state); + + // Returns the CondState for a Node. + // REQUIRES: node has a non-empty CondState. + const CondState& LookupState(const Node* node) const; + + // Resets the CondId for a given node. + void ResetId(const Node* node, CondId id); + + // Marks `node` as dead. + void MarkDead(const Node* node); + + // Determine branch execution of CondState. + BranchType FindBranchOf(CondId id, OutputTensor predicate) const; + + // Enum to represent whether one cond flow state contains another. 
+ enum ContainsResult { + kIncomparable, + kEqual, + kLhsContainsRhs, + kRhsContainsLhs + }; + + // Returns whether the lhs CondState holds wherever rhs CondState holds. I.e., + // [(p,t)] contains [(p,t), (r,t)]. + ContainsResult LhsHoldsWhereverRhsHolds(CondId lhs, CondId rhs); + + // Returns textual representation of node's CondState. + string CondStateToString(const Node* node) const; + string CondStateToString(CondId id) const; + + // Returns whether the cond state is the dead state. + bool IsDead(CondId id) const; + + // Returns whether the cond state is the empty state. + bool IsEmpty(CondId id) const; + + // Computes the predicates that have to hold for a node to execute and returns + // whether it was possible to determine the predicates that must hold. `scope` + // is populated with these predicates. Scope differs from state in that it + // does not include merge and both nodes. + bool ScopeIn(CondId id, CondId* scope); + + private: + // Hash for CondNode and CondState. + struct CondHash { + size_t operator()(const CondNode& item) const; + size_t operator()(const CondState& vec) const; + }; + + // Set to keep track of unique CondStates. + // Pointers to the entries in the unordered set are used as identifiers: + // unordered_set guarantees that the pointers remain the same. + std::unordered_set condstate_set_; + + // Mapping from Node id to CondId. + std::vector node_to_condid_map_; + + // Track the CondId for newly inserted nodes. We use a vector to quickly map + // from Node id in the original graph to the CondId, but there will be nodes + // added to the original graph (such as If nodes) whose CondState needs to be + // tracked too. + std::unordered_map added_node_mapping_; + + // Identifier of the dead flow state. The empty flow state is represented with + // a nullptr. + CondId dead_id_; +}; + +// FunctionalizeCond groups all the state used by functionalizing conditionals +// of the given graph together. 
+class FunctionalizeCond { + public: + // Functionalize all the switch-merge nodes of a loop-free graph into If + // nodes. That is, attempt to transform every remaining switch and merge nodes + // in the graph into If nodes. + // Precondition: All while loops have been removed from graph. + static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); + + // Build identity node with the same name as the merge that will be replaced + // in case the output is fetched/colocated. + Status AddIdentityNode(const Node* replacee, Node* if_node, int port); + + // Add a If node to the graph defined by def that will, amongst other, replace + // replacee in the graph. + xla::StatusOr AddIfNode(const NodeDef& def, const Node* replacee); + + // Propagates the state of a newly inserted node. + Status PropagateUpdatedState(const Node* replacee); + + // Dump graph with the CondState annotated. + void DumpGraphWithCondState(const string& name); + + private: + FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); + + // Performs the actual cond functionalization. Iterate over groups of merge + // nodes (linked by common predicate & CondIds of the incoming edges), + // from innermost to outermost, and extract into If nodes. + Status FunctionalizeInternal(); + + // Returns the forward flow state propagated along edge `e`. + // This may modify cond_state_map_. + CondStateMap::CondId StateAlongEdge(const Edge* e); + + // Determines the CondState of all the nodes in the given vector where + // the input is expected in reverse topological order. + // This populates the cond_state_map_. + Status DetermineCondStates(std::vector rev_topo_order); + + // Determine the CondState for a given node using the incoming edges + // to the node. Note: it is expected that this node's CondState is only + // determined once its input's CondState is. + Status DetermineCondState(Node* dst); + + // Helper functions for DetermineCondState. 
+ Status DetermineCondStateMerge(Node* dst); + + // Helper functions for DetermineCondStates. Determines the dst node's + // CondState by joining the src and dst's CondState where either + // the dst node is a merge or not. + // These may modify cond_state_map_. + xla::StatusOr JoinCondStatesMerge( + CondStateMap::CondId src, CondStateMap::CondId dst); + xla::StatusOr JoinCondStatesNonMerge( + CondStateMap::CondId src, CondStateMap::CondId dst); + + // Checks if a merge node is redundant and if so removes it from the graph. + Status RemoveRedundantMerge(Node* node); + + // Checks if a switch node is redundant and if so removes it from the graph. + Status RemoveRedundantSwitch(Node* node); + + // Sorts merge nodes (in reverse topological order) in order of decreasing + // nesting depth. + void SortMergeNodes(std::vector* merge_order); + + // Deletes all nodes in/consumers of `delete_nodes_`. + void DeleteReachableNodes(); + + // Member used to unique the CondState to a unique CondId and keep track of + // CondState/CondId per Node. + CondStateMap cond_state_map_; + + // Nodes to be deleted. + std::deque delete_nodes_; + + FunctionLibraryDefinition* library_; + Graph* graph_; + + friend class FunctionalizeCondTest; +}; + +} // namespace functionalize_cond + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a27f8893925855f536801a8a68855b82ac07462d --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -0,0 +1,184 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests for the functionalization of conditionals. + +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace functionalize_cond { + +class FunctionalizeCondTest : public ::testing::Test { + protected: + FunctionalizeCondTest() { + graph_.reset(new Graph(OpRegistry::Global())); + flib_def_.reset( + new FunctionLibraryDefinition(OpRegistry::Global(), fdef_lib_)); + fc_.reset(new functionalize_cond::FunctionalizeCond(graph_.get(), + flib_def_.get())); + } + + CondStateMap::CondId GetUniqueId( + const CondStateMap::CondStateMap::CondState& state) { + return fc_->cond_state_map_.GetUniqueId(state); + } + + xla::StatusOr JoinCondStatesNonMerge( + CondStateMap::CondId src, CondStateMap::CondId dst) { + return fc_->JoinCondStatesNonMerge(src, dst); + } + + xla::StatusOr JoinCondStatesMerge( + CondStateMap::CondId src, CondStateMap::CondId dst) { + return fc_->JoinCondStatesMerge(src, dst); + } + + bool ScopeIn(CondStateMap::CondId ff, CondStateMap::CondId* scope) { + return fc_->cond_state_map_.ScopeIn(ff, scope); + } + + CondStateMap::ContainsResult LhsHoldsWhereverRhsHolds( + CondStateMap::CondId lhs, CondStateMap::CondId rhs) { + return fc_->cond_state_map_.LhsHoldsWhereverRhsHolds(lhs, 
rhs); + } + + FunctionDefLibrary fdef_lib_; + std::unique_ptr fc_; + std::unique_ptr flib_def_; + std::unique_ptr graph_; +}; + +namespace { + +TEST_F(FunctionalizeCondTest, ScopeIn) { + Tensor pred_tensor(DT_BOOL, TensorShape()); + pred_tensor.flat().setZero(); + Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred"); + Tensor val_tensor(DT_INT32, TensorShape()); + val_tensor.flat().setZero(); + Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); + Node* s = test::graph::Switch(graph_.get(), val, pred); + + { + CondStateMap::CondStateMap::CondState ss; + ss.emplace_back(CondStateMap::CondNode( + CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch)); + CondStateMap::CondId id = GetUniqueId(ss); + CondStateMap::CondId scope; + ASSERT_TRUE(ScopeIn(id, &scope)); + ASSERT_TRUE(id == scope); + } + + CondStateMap::CondState empty; + { + CondStateMap::CondState ss; + ss.emplace_back(CondStateMap::CondNode( + CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth)); + ss.emplace_back( + CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge)); + CondStateMap::CondId id = GetUniqueId(ss); + CondStateMap::CondId scope_1; + ASSERT_TRUE(ScopeIn(id, &scope_1)); + ASSERT_TRUE(scope_1 == GetUniqueId(empty)); + ASSERT_TRUE(id != scope_1); + + ss.clear(); + ss.emplace_back(CondStateMap::CondNode( + CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth)); + id = GetUniqueId(ss); + CondStateMap::CondId scope_2; + ASSERT_TRUE(ScopeIn(id, &scope_2)); + + ASSERT_TRUE(LhsHoldsWhereverRhsHolds(scope_1, scope_2) == + CondStateMap::ContainsResult::kLhsContainsRhs); + } +} + +TEST_F(FunctionalizeCondTest, JoinCondStates) { + Tensor pred_tensor(DT_BOOL, TensorShape()); + pred_tensor.flat().setZero(); + Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred"); + Tensor val_tensor(DT_INT32, TensorShape()); + val_tensor.flat().setZero(); + Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); + Node* s = 
test::graph::Switch(graph_.get(), val, pred); + + CondStateMap::CondId empty = GetUniqueId({}); + + CondStateMap::CondId then_branch; + { + CondStateMap::CondState ss; + ss.emplace_back(CondStateMap::CondNode( + CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch)); + then_branch = GetUniqueId(ss); + } + CondStateMap::CondId else_branch; + { + CondStateMap::CondState ss; + ss.emplace_back(CondStateMap::CondNode( + CondStateMap::CondNode::Type::kSwitch, s, BranchType::kElseBranch)); + else_branch = GetUniqueId(ss); + } + + // An non-merge op with inputs from then and else branch. + Status status = JoinCondStatesNonMerge(then_branch, else_branch).status(); + EXPECT_TRUE(errors::IsInvalidArgument(status)); + + // Merge between then and else branch. + auto joined_or = JoinCondStatesMerge(then_branch, else_branch); + TF_EXPECT_OK(joined_or.status()); + CondStateMap::CondId joined = joined_or.ValueOrDie(); + + // Merge between then branch and both branch. + auto t = JoinCondStatesNonMerge(then_branch, joined); + // Note: this is OK in terms of constraint predication, but + TF_EXPECT_OK(t.status()); + + // Post merge the propagated forward flow state has an additional merge. + CondStateMap::CondId post_merge; + { + CondStateMap::CondState ss; + ss = *joined; + ss.emplace_back( + CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge)); + post_merge = GetUniqueId(ss); + } + + t = JoinCondStatesNonMerge(post_merge, joined); + TF_EXPECT_OK(t.status()); + EXPECT_TRUE(joined == t.ValueOrDie()); + + // No predicate that results in two paths predicated on different conditions + // merge. + t = JoinCondStatesMerge(post_merge, joined); + EXPECT_FALSE(t.ok()); + + // Post the merge we are effectively in the root scope and merging should + // result in the more restrictive post merge state. 
+ t = JoinCondStatesNonMerge(post_merge, empty); + TF_EXPECT_OK(t.status()); + EXPECT_TRUE(post_merge == t.ValueOrDie()); +} + +} // namespace +} // namespace functionalize_cond +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 0904778f97c95628c81054cd4bc2ff32ff440a33..5932be4e525dec11a8f3c59bb85e0449e76e79c0 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -21,1440 +21,24 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" +#include "tensorflow/compiler/tf2xla/functionalize_while.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" -#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/graph/node_builder.h" namespace tensorflow { -namespace { - -using xla::StatusOr; - -const char* const kArgOp = "_Arg"; -const char* const kRetValOp = "_Retval"; - -// Information about a loop argument. -struct Arg { - // Every loop argument has an Enter node. - Node* enter; - - // Is the loop argument a loop-invariant value? Taken from the `is_constant` - // attribute on the Enter node. - bool is_loop_invariant; - - // If 'is_loop_invariant' is true, the following are all nullptr. 
Non-constant - // arguments must have all of the following nodes: - Node* merge = nullptr; - Node* switch_node = nullptr; - Node* next_iteration = nullptr; - Node* exit = nullptr; -}; - -// Information about a loop frame. -struct Frame { - string name; - - // Pointer to the parent frame. The root frame has a pointer to itself. - Frame* parent = nullptr; - int num_children = 0; - - // Arguments to this loop. - std::vector args; - - // The loop condition of the loop. There should be exactly one loop condition - // in every loop. - Node* loop_cond = nullptr; - - // Set of nodes that belong to the loop frame. - std::unordered_set nodes; -}; - -// Comparison function used for sorting nodes consistently. -// a) resource variables are last, and -// b) sort lexicographically by name (for deterministic output). -struct NodeCmp { - bool operator()(const Node* lhs, const Node* rhs) const { - bool lhs_is_resource = - lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; - bool rhs_is_resource = - rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; - return std::tie(lhs_is_resource, lhs->name()) < - std::tie(rhs_is_resource, rhs->name()); - } -}; - -// Returns a textual representation of the names of the nodes in the input. -template -string NodesToString(const T& nodes) { - return strings::StrCat("{", - str_util::Join(nodes, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, - node->name()); - }), - "}"); -} - -// Copies a subgraph from `graph` to `output` by performing a reverse DFS -// starting at nodes in vector `stack`. -// `node_map` is a vector indexed by source node ID to dest nodes. -// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map` -// before the traversal clients can cut the graph. 
If a frame is provided (frame -// != nullptr), then this functions will return an error if the -// traversal leaves 'frame'; the client must add enough nodes to `node_map` to -// cut the graph and prevent the traversal from escaping. -// -// `squash_src_outputs` contains a bool for each source node ID. If true, then -// the source output on that node will be replaced by zero when copied. This is -// used when replacing a Switch node with an _Arg node. The output we are -// taking from the Switch node was not necessarily the first output, but _Arg -// nodes only have one output. By adding the Switch node to `squash_src_outputs` -// we rewrite the src_output of the corresponding edge to be 0. -Status CopySubgraph(const Graph& graph, const Frame* frame, - std::vector stack, - const std::vector& squash_src_outputs, - std::vector* node_map, Graph* output) { - VLOG(3) << "Stack: " << NodesToString(stack); - std::vector visited(graph.num_node_ids(), false); - while (!stack.empty()) { - Node* n = stack.back(); - stack.pop_back(); - - VLOG(5) << "Copying node " << n->name(); - - if (visited[n->id()]) continue; - visited[n->id()] = true; - - for (const Edge* e : n->in_edges()) { - Node* src = e->src(); - if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) { - // We traversed out of the loop frame, without encountering a cut node. - return errors::Internal("Graph traversal of loop frame ", frame->name, - " escaped frame at ", src->name(), - " without encountering an argument node."); - } - if ((*node_map)[src->id()] == nullptr) { - (*node_map)[src->id()] = output->CopyNode(src); - stack.push_back(src); - } - Node* src_copy = (*node_map)[e->src()->id()]; - int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge() - ? 
0 - : e->src_output(); - Node* dst_copy = (*node_map)[e->dst()->id()]; - output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); - } - } - return Status::OK(); -} - -StatusOr AddNode(const NodeDef& node_def, Graph* graph) { - Status status; - Node* inserted_node = graph->AddNode(node_def, &status); - if (!status.ok()) { - return status; - } - return inserted_node; -} - -// Check that the graph has no cycle containing the given node. -Status CheckNoCycleContains(const Node* node, const int num_nodes) { - std::vector ready; - ready.push_back(node); - std::vector visited(num_nodes); - while (!ready.empty()) { - const Node* current_node = ready.back(); - ready.pop_back(); - visited[current_node->id()] = true; - for (const Edge* out : current_node->out_edges()) { - if (out->dst() == node) { - return errors::Internal("Detected a cycle: ", FormatNodeForError(*node), - "(", node->def().op(), ") feeds into itself."); - } else if (!visited[out->dst()->id()]) { - ready.push_back(out->dst()); - } - } - } - return Status::OK(); -} - -StatusOr BuildArgNode(Graph* graph, DataType type, int index) { - NodeDef arg_def; - NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); - builder.Attr("T", type); - builder.Attr("index", index); - TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); - return AddNode(arg_def, graph); -} - -StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { - NodeDef ret_def; - ret_def.set_op(kRetValOp); - ret_def.set_name(strings::StrCat(kRetValOp, index)); - AddNodeAttr("T", type, &ret_def); - AddNodeAttr("index", index, &ret_def); - return AddNode(ret_def, graph); -} - -// Builds a graph for the loop condition. -Status BuildLoopCondition(const Graph& graph, Frame* frame, - std::unique_ptr* cond_output) { - VLOG(2) << "Building loop condition for " << frame->name; - *cond_output = xla::MakeUnique(graph.op_registry()); - Graph* output = cond_output->get(); - - // Map from nodes in the original graph to the condition graph. 
- std::vector node_map(graph.num_node_ids(), nullptr); - std::vector squash_src_outputs(graph.num_node_ids(), false); - - // Build one _Arg node for each Enter node. - for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; - - TF_ASSIGN_OR_RETURN(Node * arg_node, - BuildArgNode(output, arg.enter->input_type(0), i)); - if (arg.is_loop_invariant) { - node_map[arg.enter->id()] = arg_node; - } else { - node_map[arg.merge->id()] = arg_node; - } - } - - // Build a Retval node for the loop condition. The LoopCond nodes are always - // boolean because of the type constraints on the LoopCond op. - TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()], - BuildRetvalNode(output, DT_BOOL, 0)); - - // Performs a reverse DFS, copying nodes and edges to the output graph. - // The _Arg and _Retval nodes were added unconditionally above, so we are - // guaranteed to get the correct function signature. - return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs, - &node_map, output); -} - -// Builds a graph for the loop body. -Status BuildLoopBody(const Graph& graph, Frame* frame, - DataTypeVector* arg_types, - std::unique_ptr* body_output) { - VLOG(2) << "Building loop body for " << frame->name; - *body_output = xla::MakeUnique(graph.op_registry()); - Graph* output = body_output->get(); - - // Map from nodes in the original graph to the condition graph. - std::vector node_map(graph.num_node_ids(), nullptr); - std::vector squash_src_outputs(graph.num_node_ids(), false); - - // Build one _Arg node for each Enter node. 
- std::vector next_iterations; - next_iterations.reserve(frame->args.size()); - arg_types->reserve(frame->args.size()); - for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; - - DataType dtype = arg.enter->input_type(0); - arg_types->push_back(dtype); - - TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i)); - - if (dtype == DT_RESOURCE) { - // The convention of the XLA bridge is that resource variable arguments - // are only inputs to the loop body and have no corresponding output. - // TODO(b/37741920): change the convention so that DT_RESOURCE variables - // are both inputs and outputs, and then remove this case. - TF_RET_CHECK(arg.is_loop_invariant); - node_map[arg.enter->id()] = arg_node; - } else { - TF_ASSIGN_OR_RETURN(Node * retval_node, - BuildRetvalNode(output, dtype, i)); - - if (arg.is_loop_invariant) { - // Argument is loop-invariant. Forward it from the Arg to the Retval. - node_map[arg.enter->id()] = arg_node; - output->AddEdge(arg_node, 0, retval_node, 0); - } else { - // Argument is loop-varying. - node_map[arg.switch_node->id()] = arg_node; - // The Switch node has two outputs, but _Arg only has one. This tells - // the CopySubgraph function to rewrite the output number of edges from - // the _Arg node to be 0 rather than copying the output number from the - // Switch node. - squash_src_outputs[arg.switch_node->id()] = true; - node_map[arg.next_iteration->id()] = retval_node; - next_iterations.push_back(arg.next_iteration); - } - } - } - - // Performs a reverse DFS, copying nodes and edges to the output graph. - // The _Arg and _Retval nodes were added unconditionally above, so we are - // guaranteed to get the correct function signature. 
- TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations), - squash_src_outputs, &node_map, output)); - - return Status::OK(); -} - -// Copy the FunctionDef of given function from lookup_library to library, if -// it can be found in lookup_library but is missing from library. -Status AddMissingFunctionByName(const string& function_name, - const FunctionLibraryDefinition* lookup_library, - FunctionLibraryDefinition* library) { - if (!library->Find(function_name) && lookup_library->Find(function_name)) { - return library->AddFunctionDef(*lookup_library->Find(function_name)); - } - return Status::OK(); -} - -// Iterate over all functions that the given fdef refers to. Copy the missing -// FunctionDefs from lookup_library to library. -Status AddMissingFunctionDef(const FunctionDef& fdef, - const FunctionLibraryDefinition* lookup_library, - FunctionLibraryDefinition* library) { - TF_RET_CHECK(lookup_library); - for (const NodeDef& node : fdef.node_def()) { - if (library->Find(node.op())) { - continue; - } - // The function referred by 'SymbolicGradient' node is specified in its - // attribute 'f'. - if (node.op() == FunctionLibraryDefinition::kGradientOp) { - const AttrValue* attr = - AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); - if (!attr) { - return errors::InvalidArgument("SymbolicGradient is missing attr: f"); - } - const string& func_name = attr->func().name(); - TF_RETURN_IF_ERROR( - AddMissingFunctionByName(func_name, lookup_library, library)); - // Copy the user-defined gradient function if it exists. 
- const string grad_name = lookup_library->FindGradient(func_name); - if (!grad_name.empty() && library->FindGradient(func_name).empty()) { - TF_RETURN_IF_ERROR( - AddMissingFunctionByName(grad_name, lookup_library, library)); - GradientDef grad_def; - grad_def.set_function_name(func_name); - grad_def.set_gradient_func(grad_name); - TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); - } - } else if (lookup_library->Find(node.op())) { - TF_RETURN_IF_ERROR( - library->AddFunctionDef(*lookup_library->Find(node.op()))); - } - } - return Status::OK(); -} - -Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, - Graph* graph, Frame* frame, - FunctionLibraryDefinition* library) { - VLOG(2) << "Frame " << frame->name << " before: " - << dump_graph::DumpGraphToFile("functionalize_before", *graph, - library); - - // Split loop-varying Enter nodes with multiple successors. If the same - // Tensor is fed as input to multiple loop arguments, we may end up with a - // shared Enter node. We clone Enter nodes with multiple successors to - // maintain the invariant of a unique Enter node per argument of the final - // loop. - std::vector args; - for (const Arg& arg : frame->args) { - if (arg.is_loop_invariant) { - args.push_back(arg); - } else { - std::vector edges(arg.enter->out_edges().begin(), - arg.enter->out_edges().end()); - for (int i = 0; i < edges.size(); ++i) { - if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) { - continue; - } - TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name(); - Arg new_arg; - new_arg.is_loop_invariant = false; - if (i == 0) { - new_arg.enter = arg.enter; - } else { - new_arg.enter = graph->CopyNode(arg.enter); - frame->nodes.insert(new_arg.enter); - for (Edge const* e : arg.enter->in_edges()) { - graph->AddEdge(e->src(), e->src_output(), new_arg.enter, - e->IsControlEdge() ? 
Graph::kControlSlot : 0); - } - Node* dst = edges[i]->dst(); - int dst_input = edges[i]->dst_input(); - graph->RemoveEdge(edges[i]); - graph->AddEdge(new_arg.enter, 0, dst, dst_input); - } - args.push_back(new_arg); - } - } - } - frame->args = std::move(args); - - std::sort( - frame->args.begin(), frame->args.end(), - [](const Arg& a, const Arg& b) { return NodeCmp()(a.enter, b.enter); }); - - if (frame->loop_cond == nullptr) { - return errors::InvalidArgument("Loop ", frame->name, - " has no LoopCond node"); - } - - // Find the set of Switch nodes that are successors of the LoopCond. - std::unordered_set switches; - for (const Edge* edge : frame->loop_cond->out_edges()) { - if (!edge->IsControlEdge() && IsSwitch(edge->dst()) && - edge->dst_input() == 1) { - switches.insert(edge->dst()); - } - } - - // For each non-constant argument, looks for the following pattern of nodes: - // Enter ----> Merge --------> Switch --> Exit - // ^ ^ - // | | - // NextIteration LoopCond - // ^ ^ - // | | - // ... ... - for (Arg& arg : frame->args) { - if (!arg.is_loop_invariant) { - // Follow the edge from the Enter to Merge. - const Edge* enter_merge = nullptr; - for (const Edge* e : arg.enter->out_edges()) { - // Ignore control-edges to the sink node. These are allowed by the - // graph invariants, although probably they should have been stripped - // off earlier. 
- if (e->IsControlEdge() && e->dst()->IsSink()) { - continue; - } - if (enter_merge != nullptr) { - return errors::Internal("Enter node for loop-varying argument ", - FormatNodeForError(*arg.enter), - " has multiple successors: ", - FormatNodeForError(*enter_merge->dst()), - " and ", FormatNodeForError(*e->dst())); - } - enter_merge = e; - } - if (enter_merge == nullptr) { - return errors::Internal("Enter node for loop-varying argument ", - FormatNodeForError(*arg.enter), - " has zero successors"); - } - arg.merge = enter_merge->dst(); - if (!IsMerge(arg.merge)) { - return errors::InvalidArgument( - "Successor of Enter node for loop-varying argument ", - FormatNodeForError(*arg.merge), - " is not a Merge node; got: ", arg.merge->type_string()); - } - - // Find the NextIteration from the merge. There should be two inputs to - // the Merge and the NextIteration should be the other input. - if (arg.merge->input_types().size() != 2) { - return errors::InvalidArgument( - "Unexpected number of inputs to Merge node for loop-varying " - "argument ", - FormatNodeForError(*arg.merge), "; expected 2, got ", - arg.merge->input_types().size()); - } - TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), - &arg.next_iteration)); - if (!IsNextIteration(arg.next_iteration)) { - return errors::InvalidArgument( - "Expected NextIteration node as input to Merge node; got node ", - FormatNodeForError(*arg.next_iteration), " with kind ", - arg.next_iteration->type_string()); - } - - // Find the Switch successor of the Merge. There should be exactly one - // Switch node that is a successor of both the Merge and the LoopCond. 
- for (const Edge* edge : arg.merge->out_edges()) { - if (edge->dst_input() == 0 && IsSwitch(edge->dst()) && - switches.find(edge->dst()) != switches.end()) { - if (arg.switch_node != nullptr) { - return errors::InvalidArgument("Duplicate Switch successors to ", - FormatNodeForError(*arg.merge)); - } - arg.switch_node = edge->dst(); - } - } - if (arg.switch_node == nullptr) { - return errors::InvalidArgument("Missing Switch successor to ", - FormatNodeForError(*arg.merge)); - } - - // Update the device on the Identity outputs of the switch to match their - // target. These Identity outputs do not - - // Loop over the switch node's output to: - // - Find the Exit successor. - // - Set the sharding on all Identity outputs of the switch. These - // identity nodes are values used by the loop body or condition. - // The Identity node may have the wrong device so copy the device from - // one of its outputs instead. - std::deque possible_exit; - for (const Edge* edge : arg.switch_node->out_edges()) { - if (edge->src_output() == 0) { - possible_exit.push_back(edge); - } - if (IsIdentity(edge->dst())) { - TF_RETURN_IF_ERROR( - SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); - } - } - // TODO(b/67425339): Allow general graph between switch and exit. - while (!possible_exit.empty()) { - const Edge* edge = possible_exit.front(); - possible_exit.pop_front(); - if (IsExit(edge->dst())) { - if (arg.exit != nullptr) { - return errors::InvalidArgument( - "Duplicate Exit successors to ", - FormatNodeForError(*arg.switch_node)); - } - arg.exit = edge->dst(); - } else { - if (!IsIdentity(edge->dst())) { - return errors::Unimplemented("General graph between switch (", - FormatNodeForError(*arg.switch_node), - ") and exit node of frame ", - frame->name, " not supported yet."); - } - for (const Edge* out : edge->dst()->out_edges()) { - possible_exit.push_back(out); - } - } - } - } - } - - // Builds the condition and body functions. 
- std::unique_ptr cond_graph; - TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); - DataTypeVector arg_types; - std::unique_ptr body_graph; - TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); - - VLOG(2) << "Frame " << frame->name << " condition: " - << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) - << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); - - static std::atomic sequence_num(0LL); - int64 id = ++sequence_num; - NameAttrList cond_name; - cond_name.set_name(strings::StrCat("_functionalize_cond_", id)); - NameAttrList body_name; - body_name.set_name(strings::StrCat("_functionalize_body_", id)); - FunctionDef cond_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef)); - FunctionDef body_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef)); - - TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); - TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); - if (lookup_library) { - // Copy missing FunctionDefs from lookup_library to library to make library - // self-contained. - TF_RETURN_IF_ERROR( - AddMissingFunctionDef(cond_fdef, lookup_library, library)); - TF_RETURN_IF_ERROR( - AddMissingFunctionDef(body_fdef, lookup_library, library)); - } - - // Builds a While operator. 
- NodeDef while_def; - NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); - builder.Attr("T", arg_types); - builder.Attr("cond", cond_name); - builder.Attr("body", body_name); - std::vector inputs; - for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - builder.ControlInput(in_edge->src()->name()); - } else { - inputs.push_back(NodeDefBuilder::NodeOut( - in_edge->src()->name(), in_edge->src_output(), arg_types[i])); - } - } - builder.Input(inputs); - TF_RETURN_IF_ERROR(builder.Finalize(&while_def)); - TF_ASSIGN_OR_RETURN(Node * while_node, AddNode(while_def, graph)); - - // Copies edges to the Enter nodes and from the Exit nodes onto the While. - for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - graph->AddControlEdge(in_edge->src(), while_node); - } else { - graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i); - } - - if (!arg.is_loop_invariant) { - // Add output edges if the output of the loop is consumed. - if (arg.exit != nullptr) { - std::vector edges(arg.exit->out_edges().begin(), - arg.exit->out_edges().end()); - for (const Edge* edge : edges) { - Node* dst = edge->dst(); - int dst_input = edge->dst_input(); - graph->RemoveEdge(edge); - - if (dst_input == Graph::kControlSlot) { - graph->AddControlEdge(while_node, dst); - } else { - graph->AddEdge(while_node, i, dst, dst_input); - } - } - } - } - } - - // Remove the old nodes from the graph, and add the while node to the parent - // frame. 
- for (Node* node : frame->nodes) { - graph->RemoveNode(node); - } - frame->nodes.clear(); - frame->parent->nodes.insert(while_node); - - VLOG(2) << "Frame " << frame->name << " after: " - << dump_graph::DumpGraphToFile("functionalize_after", *graph, - library); - - return Status::OK(); -} - -class FunctionalizeCond { - public: - // All nodes are assumed to be either in no branch, then branch, else branch, - // or both branches (such as merge nodes). - enum Branch { - kElseBranch = 0, - kThenBranch = 1, - kBoth = 2, - kNeither = 3, - kNumBranchTypes = 4 - }; - - // Returns a textual representation of the Branch b. - static string Branch_Name(FunctionalizeCond::Branch b); - - // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf - // nodes. That is, attempt to transform every remaining switch and merge nodes - // in the graph into XlaIf nodes. - // Precondition: All while loops have been removed from graph. - static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); - - private: - // CondArgNode represents a input to the conditional and its corresponding - // switch nodes. - struct CondArgNode { - explicit CondArgNode(Node* src, int src_output) - : src(src), src_output(src_output) {} - string ToString() const { - return strings::StrCat("src=", src->name(), ":", src_output, - " switches=", NodesToString(switches)); - } - - Node* src; - int src_output; - std::vector switches; - }; - using CondArgNodes = std::vector; - - struct ForwardFlowNode { - explicit ForwardFlowNode(Branch branch = Branch::kNeither) - : branch(branch), count(0) {} - string ToString() const { - return strings::StrCat("branch=", Branch_Name(branch), " count=", count); - } - Branch branch; - int count; - }; - - // Group of switch nodes that will be part of the same XlaIf. 
- struct SwitchCluster { - explicit SwitchCluster(const Edge* predicate_edge) - : predicate_edge(predicate_edge) {} - string ToString() const { - return strings::StrCat(name, " predicate=", predicate_edge->src()->name(), - " switches=", NodesToString(switches)); - } - - string name; - const Edge* predicate_edge; - std::vector switches; - }; - - FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, - bool dump_graphs) - : library_(library), graph_(graph), dump_graphs_(dump_graphs) {} - - // Perform the actual cond functionalization. Iterate over groups of switch - // nodes (linked by common predicate), from innermost to outermost, and - // extract into XlaIf nodes. - Status FunctionalizeInternal(); - - // Determines the branch_map (mapping from node to branch of cond) and - // frontier (the nodes where the cond ends). - StatusOr, - std::unordered_set>> - DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster); - - // Returns XlaIf node created from subgraph of merge and switch nodes. This - // encapsulates the process of extracting the bodies needed for the then and - // else branch, creates a XlaIf node, removing the nodes of the branches from - // the graph and replacing the merge node with a XlaIf. - StatusOr ConvertToXlaIf(const CondArgNodes& cond_arg_nodes, - const SwitchCluster& switch_cluster, - const std::vector& switches); - - // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with. - StatusOr BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes, - const SwitchCluster& switch_cluster, - const std::vector& merge_nodes); - - // Extracts a function body corresponding to the given input edge of the merge - // node. - Status ExtractBody(const CondArgNodes& cond_arg_nodes, - const std::vector& switches, - const std::vector& merge_nodes, int input_edge, - Graph* body); - - // Adds all the input edges to `if_node` corresponding to the arguments. 
- Status AddInputEdges(const CondArgNodes& cond_arg_nodes, - const Edge* predicate_edge, Node* if_node); - - // Adds all output edges from the `if_node`. - Status AddOutputEdges(const std::vector& outputs, Node* if_node); - - // Returns the switch clusters of graph_ in postorder. Dead switch nodes are - // skipped and removed from the graph. - StatusOr> DeterminePredicateSwitchOrder(); - - // Update the state for destination based on the state of source and the node - // being updated. - Status Join(const ForwardFlowNode& src_state, const Node* dst, - ForwardFlowNode* dst_state); - - // Ensure that all nodes in the branch_map are dominated by the switch - // nodes. Returns nodes that are not dominated by the switches but are a - // control dependency of a node in the cond, and remove such control - // dependencies. - StatusOr> EnsureDominanceAndReturnNonDominatedControlNodes( - const std::unordered_map& branch_map, - const std::vector& switches); - - // Validates that the frontier of nodes for the conditional - // section are as expected. 
- Status ValidateFrontier( - const std::unordered_map& branch_map, - const std::unordered_set& frontier); - - FunctionLibraryDefinition* library_; - Graph* graph_; - bool dump_graphs_; -}; - -bool IsDeadSwitch(const Node* node) { - for (const Edge* e : node->out_edges()) { - const Node* dst = e->dst(); - if (!dst->IsIdentity()) { - return false; - } - for (const Edge* ee : dst->out_edges()) { - if (!ee->IsControlEdge() || !ee->dst()->IsSink()) { - return false; - } - } - } - return true; -} - -string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) { - const string branch_name[FunctionalizeCond::kNumBranchTypes + 1] = { - "else", "then", "both", "neither", "count"}; - return branch_name[b]; -} - -Status FunctionalizeCond::ValidateFrontier( - const std::unordered_map& - branch_map, - const std::unordered_set& frontier) { - std::unordered_set pending[kNumBranchTypes]; - for (Node* n : frontier) { - pending[branch_map.at(n).branch].insert(n); - } - TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]); - for (const Node* n : pending[kBoth]) { - TF_RET_CHECK(IsMerge(n)) << n->DebugString(); - // Merge nodes may be in then or else branch too - } - int index = (pending[kThenBranch].size() <= pending[kElseBranch].size()) - ? kThenBranch - : kElseBranch; - int other = 1 - index; - for (const Node* n : pending[index]) { - if (pending[other].find(n) != pending[other].end()) { - return errors::Internal( - "Node (", n->DebugString().c_str(), - ") in both Else and Then branch should be in Both."); - } - } - // An empty frontier indicates a dead switch. Above we attempt to remove dead - // switch nodes, but not all are removed so don't treat it as an error yet. - // TODO(jpienaar): Find out why dead switch nodes remain. 
- // if (pending[kBoth].empty() && pending[kThenBranch].empty() && - // pending[kElseBranch].empty()) { - // return errors::Internal("Unexpected empty frontier for switch nodes"); - // } - return Status::OK(); -} - -Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, - const Node* dst, ForwardFlowNode* dst_state) { - TF_RET_CHECK(dst_state->branch != Branch::kBoth && - dst_state->branch != Branch::kNumBranchTypes) - << "Unexpected/Invalid branch type: Merging " - << Branch_Name(src_state.branch) << " with " - << Branch_Name(dst_state->branch); - if (dst_state->branch == Branch::kNeither) { - dst_state->branch = src_state.branch; - } else if (src_state.branch != dst_state->branch && - src_state.branch != Branch::kNeither) { - if (IsMerge(dst)) { - dst_state->branch = Branch::kBoth; - } else { - return errors::Internal("Illegal merge:\n", src_state.ToString(), - " with ", dst_state->ToString(), " for\n", - dst->DebugString()); - } - } - ++dst_state->count; - return Status::OK(); -} - -StatusOr> -FunctionalizeCond::DeterminePredicateSwitchOrder() { - struct Cluster { - bool operator==(const Cluster& other) const { - return representative == other.representative; - } - int representative = -1; - }; - - // Perform a DFS over the graph and - // * Determine the reverse topological order of the nodes (there should be no - // cycles at this point so the post-order numbering corresponds to the - // reverse topological sorting); - // * Identify dead switches; - // * Initialize the cluster's representative; - std::vector> clusters(graph_->num_node_ids()); - std::vector dead_switches; - std::vector switch_order; - std::vector rev_topo_sorted_nodes; - DFS(*graph_, nullptr, [&](Node* n) { - clusters[n->id()].Get().representative = n->id(); - if (IsSwitch(n)) { - if (IsDeadSwitch(n)) { - dead_switches.push_back(n); - } else { - rev_topo_sorted_nodes.push_back(n); - switch_order.push_back(n); - } - } else if (n->IsOp()) { - // Exclude src and sink nodes from further 
consideration. - rev_topo_sorted_nodes.push_back(n); - } - }); - - std::vector switch_clusters; - // Return early if there are no switches in the graph. - if (switch_order.empty()) { - return switch_clusters; - } - - // Remove all dead switch nodes. - for (Node* n : dead_switches) { - VLOG(2) << "Removing dead switch: " << n->DebugString(); - graph_->RemoveNode(n); - } - - // Identify switch nodes that are part of the same control flow context by - // considering the operands of operations: an operation is part of the same - // control context as its operands unless the operation is a switch. Control - // dependencies are considered part of the same control flow context if the - // switch depth is the same (see comment below). - - // entry_cluster records the input cluster to a switch node. This is used when - // merging with a merge node where the dst's cluster is merged with the entry - // cluster of the merge node's cluster (which corresponds to a switch cluster - // and so has an entry cluster). - std::unordered_map*> entry_cluster; - - // Returns the output cluster of a node. Where the output cluster is cluster - // where the output of the node is used. For non-merge nodes this is simply - // the cluster they are part of, while for merge nodes it is the entry cluster - // of the cluster they are part of (this will correspond to the entry node of - // a switch node that dominates the merge). - auto find_output_cluster = [&](Node* n) { - UnionFind* cluster = &clusters[n->id()]; - if (!IsMerge(n)) return cluster; - auto it = entry_cluster.find(clusters[n->id()].Get().representative); - // If the cluster is not found in the entry_cluster map then an - // instruction not dominated by a switch node has been merged into the - // cluster of the merge. This indicates a failure of the clustering. 
- CHECK(it != entry_cluster.end()) - << "Unable to find entry for n=" << n->id() << " (" - << cluster->Get().representative << ")"; - return it->second; - }; - - // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier. - std::vector switch_depth(graph_->num_node_ids()); - for (auto it = rev_topo_sorted_nodes.rbegin(); - it != rev_topo_sorted_nodes.rend(); ++it) { - Node* n = *it; - - // Compute switch depth. - int new_switch_depth = 0; - for (const Edge* e : n->in_edges()) { - Node* src = e->src(); - new_switch_depth = std::max( - new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0)); - } - switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0); - - // Only merge the input operands of a switch. The switch's clustering itself - // is determined by the interaction of the switch's outputs. - if (IsSwitch(n)) { - Node* input; - TF_CHECK_OK(n->input_node(0, &input)); - entry_cluster[n->id()] = find_output_cluster(input); - UnionFind* cluster = entry_cluster[n->id()]; - int cluster_depth = switch_depth[cluster->Get().representative]; - // Merge the inputs of the switch node with one another. This results in - // predicates and control input residing in the same cluster. - for (const Edge* e : n->in_edges()) { - // Only consider the data inputs to the Switch node. 
- if (e->IsControlEdge()) continue; - - Node* src = e->src(); - UnionFind* src_cluster = find_output_cluster(src); - int src_cluster_depth = switch_depth[src_cluster->Get().representative]; - if (cluster_depth != src_cluster_depth) { - return errors::InvalidArgument( - "Unable to functionalize control flow in graph: Switch ('", - n->name(), "') has operands ('", input->name(), "' and '", - src->name(), "') that have different switch depths (", - cluster_depth, " != ", src_cluster_depth, ")"); - } - cluster->Merge(src_cluster); - } - continue; - } - - for (const Edge* e : n->in_edges()) { - Node* src = e->src(); - if (!src->IsOp()) continue; - UnionFind* cluster = find_output_cluster(src); - // Merge a node with its data operands and with its control operands if - // the src and dst are in the same ControlContext. The ControlContext is - // not explicitly available here, and instead the switch depth is used as - // a proxy here. Due to the invariant that control edges can only be from - // a containing scope to an inner scope or from the inner scope to its - // containing scope (for exit nodes), the switch depth will only match if - // the src and dst are in the same ControlContext. Control edges between - // ControlContexts are handled during the extraction. - int src_id = cluster->Get().representative; - int src_depth = switch_depth[src_id]; - if (!e->IsControlEdge() || new_switch_depth == src_depth) { - if (src_depth != new_switch_depth) { - // TODO(b/77601805) remove this when outside_compilation supports - // control flow. 
- if (str_util::StrContains(src->name(), "outside_compilation") || - str_util::StrContains(n->name(), "outside_compilation")) { - return errors::InvalidArgument( - "outside_compilation is not yet supported within TensorFlow " - "control flow constructs b/77601805"); - } - return errors::InvalidArgument( - "Unable to functionalize control flow in graph: Operand ('", - src->name(), "') and operator ('", n->name(), - "') have different switch depths (", src_depth, - " != ", new_switch_depth, ")"); - } - cluster->Merge(&clusters[n->id()]); - } - } - } - - if (dump_graphs_) { - // Mark the switch cluster each node is part of. - for (Node* n : graph_->nodes()) { - n->ClearAttr("_XlaFunctionalizeSwitchGroup"); - n->AddAttr("_XlaFunctionalizeSwitchGroup", - clusters[n->id()].Get().representative); - } - LOG(INFO) << "FunctionalizeControlFlow (with_clusters): " - << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_, - library_); - } - - // Verify all the nodes of a cluster are at the same depth. - std::unordered_map> cluster_to_depth_node; - for (Node* n : graph_->nodes()) { - int depth = switch_depth[n->id()]; - int cluster_rep = clusters[n->id()].Get().representative; - auto it = cluster_to_depth_node.find(cluster_rep); - if (it == cluster_to_depth_node.end()) { - cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n); - } else { - if (it->second.first != depth) { - return errors::Internal( - "Illegal clustering created, mismatch in depths:", "\n\t", - n->DebugString(), "(", clusters[n->id()].Get().representative, - ") at depth=", depth, " vs\n\t", it->second.second->DebugString(), - "(", clusters[n->id()].Get().representative, ") at depth ", - it->second.first); - } - } - } - - struct Hash { - size_t operator()(const std::pair& item) const { - return Hash64Combine(hash()(item.first), - std::hash()(item.second.representative)); - } - }; - - // Merge Switch nodes with common predicate. 
- std::unordered_map, int, Hash> predicate_index; - // The nodes in switch_order are in reverse topological order, but the - // clustered switches need not be (i.e., when considered as a cluster one - // element of a cluster may be later in the topological order than another - // node whose cluster is later in the topological order of clustered - // switches). - for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { - const Edge* pred_edge; - TF_CHECK_OK((*it)->input_edge(1, &pred_edge)); - // The predicate can be preceded by a identity node. Look through identity - // nodes to predicate. - while (pred_edge->src()->IsIdentity()) { - TF_CHECK_OK(pred_edge->src()->input_edge(0, &pred_edge)); - } - auto repr = std::make_pair(pred_edge->src(), clusters[(*it)->id()].Get()); - if (predicate_index.find(repr) == predicate_index.end()) { - predicate_index[repr] = switch_clusters.size(); - switch_clusters.emplace_back(pred_edge); - // Generate a name by concatenating with the cluster representative as - // there could be multiple switch clusters with the same predicate. 
- switch_clusters[predicate_index[repr]].name = strings::StrCat( - pred_edge->src()->name(), "_", repr.second.representative, "_If"); - } - switch_clusters[predicate_index[repr]].switches.push_back(*it); - } - - return switch_clusters; -} - -StatusOr> -FunctionalizeCond::EnsureDominanceAndReturnNonDominatedControlNodes( - const std::unordered_map& branch_map, - const std::vector& switches) { - std::vector old_control_nodes; - for (const auto& kv : branch_map) { - if (kv.second.count != kv.first->in_edges().size()) { - std::vector delete_edges; - for (const Edge* in : kv.first->in_edges()) { - auto it = branch_map.find(in->src()); - if (it == branch_map.end()) { - if (in->IsControlEdge()) { - old_control_nodes.push_back(in->src()); - delete_edges.push_back(in); - } else { - if (IsSwitch(in->src())) { - if (std::find(switches.begin(), switches.end(), in->src()) == - switches.end()) { - return errors::Internal( - "Unexpected switch node found during flow forward: ", - in->src()->DebugString()); - } - continue; - } - return errors::InvalidArgument( - "Value ", kv.first->name(), "'s input, ", in->src()->name(), - ", is not dominated by switch nodes ", NodesToString(switches)); - } - } - } - // Remove control edges from nodes that are not dominated by the switch - // nodes. New control dependencies will be added between these nodes and - // the XlaIf node inserted. 
- for (const Edge* e : delete_edges) { - graph_->RemoveEdge(e); - } - } - } - return old_control_nodes; -} - -StatusOr< - std::pair, - std::unordered_set>> -FunctionalizeCond::DetermineBranchMapAndFrontier( - const SwitchCluster& switch_cluster) { - std::unordered_map branch_map; - std::unordered_set frontier; - std::vector stack = switch_cluster.switches; - std::vector visited(graph_->num_node_ids(), false); - while (!stack.empty()) { - Node* n = stack.back(); - stack.pop_back(); - - if (visited[n->id()]) { - continue; - } - visited[n->id()] = true; - - // Propagate branch state along each edge of a switch node. - bool sink_only = true; - for (const Edge* e : n->out_edges()) { - Node* out = e->dst(); - if (!out->IsOp()) { - continue; - } - sink_only = false; - // Propagate branch information. - ForwardFlowNode& ffn = branch_map[out]; - if (IsSwitch(n)) { - int index = e->IsControlEdge() ? Branch::kNeither : e->src_output(); - TF_RETURN_WITH_CONTEXT_IF_ERROR( - Join(ForwardFlowNode(Branch(index)), out, &ffn), " when joining ", - e->DebugString()); - } else { - TF_RETURN_WITH_CONTEXT_IF_ERROR(Join(branch_map[n], out, &ffn), - " when joining ", e->DebugString()); - } - if (IsMerge(out)) { - if (out->in_edges().size() == ffn.count) { - frontier.insert(out); - } - } else if (!visited[out->id()]) { - stack.push_back(out); - } - } - if (sink_only) { - if (!IsIdentity(n)) { - VLOG(1) << "Feeding into sink: " << n->DebugString(); - } - } - } - - if (dump_graphs_) { - for (const auto& kv : branch_map) { - // Append attribute to the graph if running with logging to make the - // changes clearer in the visualization. 
- kv.first->AddAttr("_XlaFunctionalizeBranch", - Branch_Name(kv.second.branch)); - } - } - return std::make_pair(std::move(branch_map), std::move(frontier)); -} - -Status FunctionalizeCond::FunctionalizeInternal() { - TF_ASSIGN_OR_RETURN(std::vector predicate_switch_order, - DeterminePredicateSwitchOrder()); - - // Iterate from innermost set of clustered switches to outermost, replacing - // matching switch->merge subgraphs with single XlaIf nodes. - for (auto it = predicate_switch_order.rbegin(); - it != predicate_switch_order.rend(); ++it) { - auto& ps = *it; - VLOG(3) << "Flow down from: " << ps.ToString(); - - std::unordered_map branch_map; - std::unordered_set frontier; - TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier), - DetermineBranchMapAndFrontier(ps)); - - if (dump_graphs_) - LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): " - << dump_graph::DumpGraphToFile("functionalize_bc", *graph_, - library_); - TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier)); - - struct Hash { - size_t operator()(const std::pair& item) const { - return Hash64Combine(hash()(item.first), - std::hash()(item.second)); - } - }; - - // Sort the merge and switch nodes using NodeCmp. The switch-nodes are - // further grouped (post sorting) by input to the switch node as in the - // functionalized form each input will be passed in only once. This grouping - // should retain the sorted order. 
- CondArgNodes cond_arg_nodes; - std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp()); - std::unordered_map, int, Hash> input_index; - for (Node* switch_node : ps.switches) { - const Edge* e; - TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e)); - std::pair key = std::make_pair(e->src(), e->src_output()); - if (input_index.find(key) == input_index.end()) { - input_index[key] = cond_arg_nodes.size(); - cond_arg_nodes.emplace_back(key.first, key.second); - } - cond_arg_nodes.at(input_index.at(key)).switches.push_back(switch_node); - } - std::vector merge_nodes(frontier.begin(), frontier.end()); - std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp()); - - TF_ASSIGN_OR_RETURN(std::vector old_control_nodes, - EnsureDominanceAndReturnNonDominatedControlNodes( - branch_map, ps.switches)); - - TF_ASSIGN_OR_RETURN(Node * if_node, - ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes)); - for (Node* old : old_control_nodes) { - graph_->AddControlEdge(old, if_node); - } - - for (auto& del_kv : branch_map) { - graph_->RemoveNode(del_kv.first); - } - for (auto& kv : cond_arg_nodes) { - for (Node* node : kv.switches) { - graph_->RemoveNode(node); - } - } - if (dump_graphs_) - LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): " - << dump_graph::DumpGraphToFile("functionalize_ac", *graph_, - library_); - } - return Status::OK(); -} - -StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( - const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, - const std::vector& merge_nodes) { - VLOG(2) << "Build if op for " << switch_cluster.name; - - NodeDef if_def; - // Create a new If node using the name of the merge node. 
- NodeDefBuilder builder(switch_cluster.name, "XlaIf"); - string branch[] = {"else_branch", "then_branch"}; - for (int i = 0; i < 2; ++i) { - static std::atomic sequence_num(0LL); - int64 id = ++sequence_num; - - NameAttrList body_name; - body_name.set_name( - strings::StrCat("_functionalize_if_", branch[i], "_", id)); - auto body = xla::MakeUnique(graph_->op_registry()); - TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches, - merge_nodes, i, body.get())); - VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); - FunctionDef body_fdef; - TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); - TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef)); - builder.Attr(branch[i], body_name); - } - - // Build input type. - std::vector inputs; - DataTypeVector in_arg_types; - for (auto& kv : cond_arg_nodes) { - bool inserted = false; - for (const Node* arg : kv.switches) { - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - builder.ControlInput(in_edge->src()->name()); - } else { - if (!inserted) { - DataType dtype = arg->input_type(0); - inputs.emplace_back(NodeDefBuilder::NodeOut( - in_edge->src()->name(), in_edge->src_output(), dtype)); - in_arg_types.push_back(dtype); - inserted = true; - } - } - } - } - builder.Attr("Tin", in_arg_types); - - // Build output type. - DataTypeVector out_type; - for (const Node* merge : merge_nodes) { - DataType dtype = merge->output_type(0); - out_type.push_back(dtype); - } - builder.Attr("Tout", out_type); - - builder.Attr("Tcond", DT_BOOL); - builder.Device(switch_cluster.predicate_edge->src()->assigned_device_name()); - // Conditional should be the first input ... - builder.Input(NodeDefBuilder::NodeOut( - switch_cluster.predicate_edge->src()->name(), - switch_cluster.predicate_edge->src_output(), - switch_cluster.predicate_edge->src()->output_type(0))); - // ... followed by the other inputs. 
- builder.Input(inputs); - - TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); - TF_ASSIGN_OR_RETURN(Node * if_node, AddNode(if_def, graph_)); - return if_node; -} - -Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, - const std::vector& switches, - const std::vector& merge_nodes, - int input_edge, Graph* body) { - VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge " - << input_edge; - std::vector squash_src_outputs(graph_->num_node_ids(), false); - std::vector node_map(graph_->num_node_ids(), nullptr); - int arg_count = 0; - for (auto& kv : cond_arg_nodes) { - Node* arg_node = nullptr; - for (const auto* arg : kv.switches) { - DataType dtype = arg->input_type(0); - if (arg_node == nullptr) { - TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++)); - } - node_map.at(arg->id()) = arg_node; - squash_src_outputs.at(arg->id()) = true; - } - } - - std::vector stack; - stack.reserve(merge_nodes.size()); - for (int j = 0; j < merge_nodes.size(); ++j) { - Node* node = merge_nodes[j]; - TF_ASSIGN_OR_RETURN(node_map.at(node->id()), - BuildRetvalNode(body, node->output_type(0), - /*index=*/j)); - const Edge* in_edge; - TF_RETURN_IF_ERROR(node->input_edge(input_edge, &in_edge)); - Node* in = in_edge->src(); - if (node_map.at(in->id()) == nullptr) { - node_map.at(in->id()) = body->CopyNode(in); - } - - if (std::find(switches.begin(), switches.end(), in) == switches.end()) { - body->AddEdge(node_map.at(in->id()), in_edge->src_output(), - node_map.at(node->id()), 0); - } else { - body->AddEdge(node_map.at(in->id()), 0, node_map.at(node->id()), 0); - // Don't include input nodes that are already just returned in stack. 
- continue; - } - stack.push_back(in); - } - - return CopySubgraph(*graph_, nullptr, stack, squash_src_outputs, &node_map, - body); -} - -Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes, - const Edge* predicate_edge, - Node* if_node) { - VLOG(3) << "AddInputEdges for " << if_node->name(); - int index = 0; - graph_->AddEdge(predicate_edge->src(), predicate_edge->src_output(), if_node, - index++); - for (auto& arg : cond_arg_nodes) { - if (arg.src_output == Graph::kControlSlot) { - graph_->AddControlEdge(arg.src, if_node); - } else { - graph_->AddEdge(arg.src, arg.src_output, if_node, index++); - } - } - return Status::OK(); -} - -Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, - Node* if_node) { - VLOG(3) << "AddOutputEdges for " << if_node->name(); - for (int i = 0; i < outputs.size(); ++i) { - Node* node = outputs[i]; - std::vector edges(node->out_edges().begin(), - node->out_edges().end()); - for (const Edge* edge : edges) { - Node* dst = edge->dst(); - int dst_input = edge->dst_input(); - - if (edge->src_output() > 0) { - return errors::Unimplemented("Output of index (", edge->src_output(), - ") of merge node ", node->name()); - } - - int src_output = - dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; - graph_->RemoveEdge(edge); - graph_->AddEdge(if_node, src_output, dst, dst_input); - } - } - return Status::OK(); -} - -StatusOr FunctionalizeCond::ConvertToXlaIf( - const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, - const std::vector& merge_nodes) { - VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> " - << NodesToString(merge_nodes); - - // Extract bodies and builds a If operator. 
- TF_ASSIGN_OR_RETURN( - Node * if_node, - BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes)); - TF_RETURN_IF_ERROR( - AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node)); - TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); - // Check that the if_node doesn't feed into itself. - TF_RETURN_WITH_CONTEXT_IF_ERROR( - CheckNoCycleContains(if_node, graph_->num_node_ids()), - "ConvertToXlaIf failed."); - - return if_node; -} - -Status FunctionalizeCond::Functionalize(Graph* graph, - FunctionLibraryDefinition* library) { - VLOG(1) << "FunctionalizeCond::Functionalize"; - FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2)); - return fc.FunctionalizeInternal(); -} - -} // namespace - -// Transformation that converts TensorFlow's graph control flow constructs into -// functional equivalents. -Status FunctionalizeControlFlow(Graph* graph, - FunctionLibraryDefinition* library) { - return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); -} - Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library) { @@ -1462,98 +46,26 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, << dump_graph::DumpGraphToFile("functionalize_initial", *graph, library); - // Note: BuildControlFlowInfo() requires that the graph's source node is - // connected to all source nodes in the graph. Many graphs violate this - // invariant. - std::vector cf_info; - std::vector unreachable_nodes; - TF_RETURN_WITH_CONTEXT_IF_ERROR( - BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes), - "FunctionalizeControlFlow failed"); - if (!unreachable_nodes.empty()) { - return errors::InvalidArgument( - "The following nodes are unreachable from the source in the graph: ", - errors::FormatNodeNamesForError(unreachable_nodes)); - } - - // Builds Frames, indexed by name. 
- std::unordered_map frames; - for (Node* node : graph->op_nodes()) { - const ControlFlowInfo& cf = cf_info[node->id()]; - - VLOG(2) << "node: " << node->name() << " (" << node->id() - << ") frame_name: " << cf.frame_name - << " frame: " << (cf.frame ? cf.frame->name() : "---") - << " parent_frame: " - << (cf.parent_frame ? cf.parent_frame->name() : "---"); - TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr); - - Frame& frame = frames[cf.frame_name]; - Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; - if (frame.parent == nullptr) { - frame.parent = parent; - frame.name = cf.frame_name; - ++parent->num_children; - } - - if (IsEnter(node)) { - Arg arg; - arg.enter = node; - TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant", - &arg.is_loop_invariant)); - frame.args.push_back(arg); - } else if (IsLoopCond(node)) { - frame.loop_cond = node; - } - frame.nodes.insert(node); - } - - // Adds frames with no children (i.e., the innermost frames) to a worklist. - std::deque worklist; - for (auto& frame : frames) { - if (frame.second.num_children == 0) { - worklist.push_back(&frame.second); - } - } - - // Eliminate loops from innermost to outermost. - while (!worklist.empty()) { - Frame* frame = worklist.front(); - worklist.pop_front(); - if (frame->parent == frame) { - // Skip the root frame. - continue; - } - - TF_RETURN_IF_ERROR( - FunctionalizeLoop(lookup_library, graph, frame, library)); - - // If the parent has no remaining children, add it to the worklist. - --frame->parent->num_children; - if (frame->parent->num_children == 0) { - worklist.push_back(frame->parent); - } - } - // There should be no cycle at this point, since while loops have been removed - // from graph. - // Check that the newly added XlaWhile nodes don't feed into themselves. 
- for (const Node* node : graph->op_nodes()) { - if (node->def().op() == "XlaWhile") { - TF_RETURN_WITH_CONTEXT_IF_ERROR( - CheckNoCycleContains(node, graph->num_node_ids()), - "FunctionalizeLoop failed."); - } - } + // Functionalize and remove while loops from graph. + TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(lookup_library, graph, library)); // FunctionalizeControlFlow is invoked for every function, so the loops's // bodies and conditionals that were extracted into functions will be handled // in successive invocations. - TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library)); + TF_RETURN_IF_ERROR(FunctionalizeCond(graph, library)); VLOG(2) << "FunctionalizeControlFlow (final): " << dump_graph::DumpGraphToFile("functionalize_final", *graph, library); + return Status::OK(); } +// Transformation that converts TensorFlow's graph control flow constructs into +// functional equivalents. +Status FunctionalizeControlFlow(Graph* graph, + FunctionLibraryDefinition* library) { + return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index d941041d15532446d1413f16fe64602bfb1a7daa..55600f2a8b5302cef26b9be4ccd0f8804476a17a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -16,14 +16,16 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While -// operators, suitable for XLA compilation. 
If lookup_library is provided, use -// it to make the library for control flow self-contained. +// operators and tf.cond() conditionals into function If operators, suitable for +// XLA compilation. If lookup_library is provided, use it to make the library +// for control flow self-contained. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index ccf249b35d66861888ad5e5e904b5f63b8ac50a1..c068a4110c0bb14282379eb7a3cbdae4e80ddbd6 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -37,12 +37,12 @@ limitations under the License. namespace tensorflow { namespace { -// Returns the names of the "then" and "else" functions for the XlaIf node in a +// Returns the names of the "then" and "else" functions for the If node in a // graph. 
Status FindIfThenAndElse(const GraphDef& graph, string* op_name, NameAttrList* then_fn, NameAttrList* else_fn) { for (const NodeDef& node : graph.node()) { - if (node.op() == "XlaIf") { + if (node.op() == "If") { *op_name = node.name(); const NameAttrList* result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result)); @@ -52,7 +52,7 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name, return Status::OK(); } } - return errors::NotFound("No XlaIf node found in graph"); + return errors::NotFound("No If node found in graph"); } // Graph: @@ -115,8 +115,13 @@ TEST(FunctionalizeControlFlow, Conditional) { auto if_op = ops::XlaIf(scope.WithOpName(op_name), less, std::initializer_list{less, y, x}, then_fn, else_fn, {DT_INT32}); + auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); + // TODO(jpienaar): Create wrapper for IfOp. + for (NodeDef& n : *expected.mutable_node()) { + if (n.op() == "XlaIf") n.set_op("If"); + } TF_EXPECT_GRAPH_EQ(expected, graph_def); } @@ -800,11 +805,11 @@ TEST(FunctionalizeControlFlow, Complex) { auto assign = ops::AssignAddVariableOp( scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx); - auto one = - ops::Const(scope.WithOpName("outer/inner/One") - .WithControlDependencies( - gtl::ArraySlice{assign.operation}), - 1); + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); auto add_j = ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); @@ -818,7 +823,7 @@ TEST(FunctionalizeControlFlow, Complex) { scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); auto add_i = ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(gtl::ArraySlice{ + .WithControlDependencies(absl::Span{ exit_j.output.op(), exit_k.output.op()}), identity_i, one_outer); auto next_iteration_i = @@ -924,7 +929,7 @@ TEST(FunctionalizeControlFlow, 
Complex) { scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); auto add_i = ops::Add(scope.WithOpName("outer/add") - .WithControlDependencies(gtl::ArraySlice{ + .WithControlDependencies(absl::Span{ while_op[0].op(), while_op[1].op()}), identity_i, one_outer); @@ -986,11 +991,11 @@ TEST(FunctionalizeControlFlow, Complex) { auto assign = ops::AssignAddVariableOp( scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); - auto one = - ops::Const(scope.WithOpName("outer/inner/One") - .WithControlDependencies( - gtl::ArraySlice{assign.operation}), - 1); + auto one = ops::Const( + scope.WithOpName("outer/inner/One") + .WithControlDependencies( + absl::Span{assign.operation}), + 1); auto add_j = ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); @@ -1013,63 +1018,5 @@ TEST(FunctionalizeControlFlow, Complex) { } } -TEST(FunctionalizeControlFlow, Cycle) { - std::unique_ptr graph(new Graph(OpRegistry::Global())); - // ----------------------------------------------------- - // | | - // | v - // less -> switch_1 --> add -> merge_1 -> identity -> switch_2 - // | ^ | - // | | v - // --------> one -------------------------> add_2 ---> merge_2 - { - Scope scope = Scope::NewRootScope().ExitOnError(); - - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); - auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); - auto two = - ops::Const(scope.WithOpName("cond/two") - .WithControlDependencies(switch_1.output_true), - 2); - auto mul = ops::Multiply(scope.WithOpName("cond/true/mul"), - switch_1.output_true, two); - auto one = - ops::Const(scope.WithOpName("cond/one") - .WithControlDependencies(switch_1.output_false), - 1); - auto add = ops::Add(scope.WithOpName("cond/false/add"), - switch_1.output_false, one); - - auto merge_1 = ops::Merge(scope.WithOpName("cond/Merge"), - std::initializer_list{add, 
mul}); - auto identity = - ops::Identity(scope.WithOpName("cond/Merge/identity"), merge_1.output); - auto switch_2 = - ops::Switch(scope.WithOpName("grad/cond/Switch"), identity, less); - auto add_2 = ops::Add(scope.WithOpName("cond_2/false/add"), - switch_2.output_false, one); - auto mul_2 = ops::Multiply(scope.WithOpName("cond_2/true/mul"), - switch_2.output_true, two); - auto merge_2 = ops::Merge(scope.WithOpName("cond_2/Merge"), - std::initializer_list{add_2, mul_2}); - TF_ASSERT_OK(scope.ToGraph(graph.get())); - } - // No cycle before functionalize control flow. - TF_EXPECT_OK(graph::ValidateGraphHasNoCycle(*graph)); - FunctionLibraryDefinition library(OpRegistry::Global(), {}); - // switch_1 and switch_2 have the same switch depth. They are replaced by a - // single XlaIf node during FunctionalizeControlFlow, resulting in a cycle: - // less -> XlaIf <--> identity. - Status status = FunctionalizeControlFlow(graph.get(), &library); - EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detected a cycle")) - << status.error_message(); - EXPECT_TRUE( - str_util::StrContains(status.error_message(), "{{node cond/Less_5_If}}")) - << status.error_message(); -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..924fcdd9cd72a6472e0b2748680f2552fa65ec79 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" + +#include "tensorflow/core/framework/node_def.pb.h" + +namespace tensorflow { + +bool NodeCmpByNameResourcesLast::operator()(const Node* lhs, + const Node* rhs) const { + bool lhs_is_resource = + lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; + bool rhs_is_resource = + rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; + return std::tie(lhs_is_resource, lhs->name()) < + std::tie(rhs_is_resource, rhs->name()); +} + +xla::StatusOr AddNodeDefToGraph(const NodeDef& node_def, Graph* graph) { + Status status; + Node* inserted_node = graph->AddNode(node_def, &status); + if (!status.ok()) { + return status; + } + return inserted_node; +} + +xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { + const char* const kRetValOp = "_Retval"; + NodeDef ret_def; + ret_def.set_op(kRetValOp); + ret_def.set_name(strings::StrCat(kRetValOp, index)); + AddNodeAttr("T", type, &ret_def); + AddNodeAttr("index", index, &ret_def); + return AddNodeDefToGraph(ret_def, graph); +} + +// Check that the graph has no cycle containing the given node. 
+Status CheckNodeNotInCycle(const Node* node, const int num_nodes) { + std::vector ready; + ready.push_back(node); + std::vector visited(num_nodes); + while (!ready.empty()) { + const Node* current_node = ready.back(); + ready.pop_back(); + visited[current_node->id()] = true; + for (const Edge* out : current_node->out_edges()) { + if (out->dst() == node) { + return errors::Internal("Detected a cycle: ", FormatNodeForError(*node), + " (", node->def().op(), ") feeds into itself."); + } else if (!visited[out->dst()->id()]) { + ready.push_back(out->dst()); + } + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h new file mode 100644 index 0000000000000000000000000000000000000000..61940e3586c59ffc660eaac8f8d035fbbbdfeffd --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ + +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/graph/graph.h" + +// Utility functions shared between functionalize cond and while. 
+ +namespace tensorflow { + +// Check that the graph has no cycle containing the given node. +Status CheckNodeNotInCycle(const Node* node, const int num_nodes); + +// Comparison function used for sorting nodes consistently. +// a) resource variables are last, and +// b) sort lexicographically by name (for deterministic output). +struct NodeCmpByNameResourcesLast { + bool operator()(const Node* lhs, const Node* rhs) const; +}; + +// Returns the Node* created from the NodeDef in the Graph. +xla::StatusOr AddNodeDefToGraph(const NodeDef& node_def, Graph* graph); + +// Build a retval node of given type and index. +xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index); + +// Returns a textual representation of the names of the nodes in the input. +template +string NodesToString(const T& nodes) { + return strings::StrCat("{", + absl::StrJoin(nodes, ",", + [](string* output, const Node* node) { + strings::StrAppend(output, + node->name()); + }), + "}"); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc new file mode 100644 index 0000000000000000000000000000000000000000..6e3c4b0e0f695f0073f2c8aa1a4b342e39ea4be5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -0,0 +1,668 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_while.h" + +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/node_builder.h" + +namespace tensorflow { +namespace { + +using xla::StatusOr; + +// Information about a loop argument. +struct Arg { + // Every loop argument has an Enter node. + Node* enter; + + // Is the loop argument a loop-invariant value? Taken from the `is_constant` + // attribute on the Enter node. + bool is_loop_invariant; + + // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant + // arguments must have all of the following nodes: + Node* merge = nullptr; + Node* switch_node = nullptr; + Node* next_iteration = nullptr; + Node* exit = nullptr; +}; + +// Information about a loop frame. +struct Frame { + string name; + + // Pointer to the parent frame. The root frame has a pointer to itself. + Frame* parent = nullptr; + int num_children = 0; + + // Arguments to this loop. + std::vector args; + + // The loop condition of the loop. There should be exactly one loop condition + // in every loop. + Node* loop_cond = nullptr; + + // Set of nodes that belong to the loop frame. 
+ std::unordered_set nodes; +}; + +// Copies a subgraph from `graph` to `output` by performing a reverse DFS +// starting at nodes in vector `stack`. +// `node_map` is a vector indexed by source node ID to dest nodes. +// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map` +// before the traversal clients can cut the graph. If a frame is provided (frame +// != nullptr), then this functions will return an error if the +// traversal leaves 'frame'; the client must add enough nodes to `node_map` to +// cut the graph and prevent the traversal from escaping. +// +// `squash_src_outputs` contains a bool for each source node ID. If true, then +// the source output on that node will be replaced by zero when copied. This is +// used when replacing a Switch node with an _Arg node. The output we are +// taking from the Switch node was not necessarily the first output, but _Arg +// nodes only have one output. By adding the Switch node to `squash_src_outputs` +// we rewrite the src_output of the corresponding edge to be 0. +Status CopySubgraph(const Graph& graph, const Frame* frame, + std::vector stack, + const std::vector& squash_src_outputs, + std::vector* node_map, Graph* output) { + VLOG(3) << "Stack: " << NodesToString(stack); + std::vector visited(graph.num_node_ids(), false); + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); + + VLOG(5) << "Copying node " << n->name(); + + if (visited[n->id()]) continue; + visited[n->id()] = true; + + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) { + // We traversed out of the loop frame, without encountering a cut node. 
+ return errors::Internal("Graph traversal of loop frame ", frame->name, + " escaped frame at ", src->name(), + " without encountering an argument node."); + } + if ((*node_map)[src->id()] == nullptr) { + (*node_map)[src->id()] = output->CopyNode(src); + stack.push_back(src); + } + Node* src_copy = (*node_map)[e->src()->id()]; + int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge() + ? 0 + : e->src_output(); + Node* dst_copy = (*node_map)[e->dst()->id()]; + output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); + } + } + return Status::OK(); +} + +StatusOr BuildArgNode(Graph* graph, DataType type, int index) { + const char* const kArgOp = "_Arg"; + NodeDef arg_def; + NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); + builder.Attr("T", type); + builder.Attr("index", index); + TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); + return AddNodeDefToGraph(arg_def, graph); +} + +// Builds a graph for the loop condition. +Status BuildLoopCondition(const Graph& graph, Frame* frame, + std::unique_ptr* cond_output) { + VLOG(2) << "Building loop condition for " << frame->name; + *cond_output = absl::make_unique(graph.op_registry()); + Graph* output = cond_output->get(); + + // Map from nodes in the original graph to the condition graph. + std::vector node_map(graph.num_node_ids(), nullptr); + std::vector squash_src_outputs(graph.num_node_ids(), false); + + // Build one _Arg node for each Enter node. + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + + TF_ASSIGN_OR_RETURN(Node * arg_node, + BuildArgNode(output, arg.enter->input_type(0), i)); + if (arg.is_loop_invariant) { + node_map[arg.enter->id()] = arg_node; + } else { + node_map[arg.merge->id()] = arg_node; + } + } + + // Build a Retval node for the loop condition. The LoopCond nodes are always + // boolean because of the type constraints on the LoopCond op. 
+ TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()], + BuildRetvalNode(output, DT_BOOL, 0)); + + // Performs a reverse DFS, copying nodes and edges to the output graph. + // The _Arg and _Retval nodes were added unconditionally above, so we are + // guaranteed to get the correct function signature. + return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs, + &node_map, output); +} + +// Builds a graph for the loop body. +Status BuildLoopBody(const Graph& graph, Frame* frame, + DataTypeVector* arg_types, + std::unique_ptr* body_output) { + VLOG(2) << "Building loop body for " << frame->name; + *body_output = absl::make_unique(graph.op_registry()); + Graph* output = body_output->get(); + + // Map from nodes in the original graph to the condition graph. + std::vector node_map(graph.num_node_ids(), nullptr); + std::vector squash_src_outputs(graph.num_node_ids(), false); + + // Build one _Arg node for each Enter node. + std::vector next_iterations; + next_iterations.reserve(frame->args.size()); + arg_types->reserve(frame->args.size()); + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + + DataType dtype = arg.enter->input_type(0); + arg_types->push_back(dtype); + + TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i)); + + if (dtype == DT_RESOURCE) { + // The convention of the XLA bridge is that resource variable arguments + // are only inputs to the loop body and have no corresponding output. + // TODO(b/37741920): change the convention so that DT_RESOURCE variables + // are both inputs and outputs, and then remove this case. + TF_RET_CHECK(arg.is_loop_invariant); + node_map[arg.enter->id()] = arg_node; + } else { + TF_ASSIGN_OR_RETURN(Node * retval_node, + BuildRetvalNode(output, dtype, i)); + + if (arg.is_loop_invariant) { + // Argument is loop-invariant. Forward it from the Arg to the Retval. 
+ node_map[arg.enter->id()] = arg_node; + output->AddEdge(arg_node, 0, retval_node, 0); + } else { + // Argument is loop-varying. + node_map[arg.switch_node->id()] = arg_node; + // The Switch node has two outputs, but _Arg only has one. This tells + // the CopySubgraph function to rewrite the output number of edges from + // the _Arg node to be 0 rather than copying the output number from the + // Switch node. + squash_src_outputs[arg.switch_node->id()] = true; + node_map[arg.next_iteration->id()] = retval_node; + next_iterations.push_back(arg.next_iteration); + } + } + } + + // Performs a reverse DFS, copying nodes and edges to the output graph. + // The _Arg and _Retval nodes were added unconditionally above, so we are + // guaranteed to get the correct function signature. + TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations), + squash_src_outputs, &node_map, output)); + + return Status::OK(); +} + +// Copy the FunctionDef of given function from lookup_library to library, if +// it can be found in lookup_library but is missing from library. +Status AddMissingFunctionByName(const string& function_name, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + if (!library->Find(function_name) && lookup_library->Find(function_name)) { + return library->AddFunctionDef(*lookup_library->Find(function_name)); + } + return Status::OK(); +} + +// Iterate over all functions that the given fdef refers to. Copy the missing +// FunctionDefs from lookup_library to library. +Status AddMissingFunctionDef(const FunctionDef& fdef, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + TF_RET_CHECK(lookup_library); + for (const NodeDef& node : fdef.node_def()) { + if (library->Find(node.op())) { + continue; + } + // The function referred by 'SymbolicGradient' node is specified in its + // attribute 'f'. 
+ if (node.op() == FunctionLibraryDefinition::kGradientOp) { + const AttrValue* attr = + AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); + if (!attr) { + return errors::InvalidArgument("SymbolicGradient is missing attr: f"); + } + const string& func_name = attr->func().name(); + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(func_name, lookup_library, library)); + // Copy the user-defined gradient function if it exists. + const string grad_name = lookup_library->FindGradient(func_name); + if (!grad_name.empty() && library->FindGradient(func_name).empty()) { + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(grad_name, lookup_library, library)); + GradientDef grad_def; + grad_def.set_function_name(func_name); + grad_def.set_gradient_func(grad_name); + TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); + } + } else if (lookup_library->Find(node.op())) { + TF_RETURN_IF_ERROR( + library->AddFunctionDef(*lookup_library->Find(node.op()))); + } + } + return Status::OK(); +} + +Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, Frame* frame, + FunctionLibraryDefinition* library) { + VLOG(2) << "Frame " << frame->name << " before: " + << dump_graph::DumpGraphToFile("functionalize_before", *graph, + library); + + // Split loop-varying Enter nodes with multiple successors. If the same + // Tensor is fed as input to multiple loop arguments, we may end up with a + // shared Enter node. We clone Enter nodes with multiple successors to + // maintain the invariant of a unique Enter node per argument of the final + // loop. 
+ std::vector args; + for (const Arg& arg : frame->args) { + if (arg.is_loop_invariant) { + args.push_back(arg); + } else { + std::vector edges(arg.enter->out_edges().begin(), + arg.enter->out_edges().end()); + for (int i = 0; i < edges.size(); ++i) { + if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) { + continue; + } + TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name(); + Arg new_arg; + new_arg.is_loop_invariant = false; + if (i == 0) { + new_arg.enter = arg.enter; + } else { + new_arg.enter = graph->CopyNode(arg.enter); + frame->nodes.insert(new_arg.enter); + for (Edge const* e : arg.enter->in_edges()) { + graph->AddEdge(e->src(), e->src_output(), new_arg.enter, + e->IsControlEdge() ? Graph::kControlSlot : 0); + } + Node* dst = edges[i]->dst(); + int dst_input = edges[i]->dst_input(); + graph->RemoveEdge(edges[i]); + graph->AddEdge(new_arg.enter, 0, dst, dst_input); + } + args.push_back(new_arg); + } + } + } + frame->args = std::move(args); + + std::sort(frame->args.begin(), frame->args.end(), + [](const Arg& a, const Arg& b) { + return NodeCmpByNameResourcesLast()(a.enter, b.enter); + }); + + if (frame->loop_cond == nullptr) { + return errors::InvalidArgument("Loop ", frame->name, + " has no LoopCond node"); + } + + // Find the set of Switch nodes that are successors of the LoopCond. + std::unordered_set switches; + for (const Edge* edge : frame->loop_cond->out_edges()) { + if (!edge->IsControlEdge() && IsSwitch(edge->dst()) && + edge->dst_input() == 1) { + switches.insert(edge->dst()); + } + } + + // For each non-constant argument, looks for the following pattern of nodes: + // Enter ----> Merge --------> Switch --> Exit + // ^ ^ + // | | + // NextIteration LoopCond + // ^ ^ + // | | + // ... ... + for (Arg& arg : frame->args) { + if (!arg.is_loop_invariant) { + // Follow the edge from the Enter to Merge. + const Edge* enter_merge = nullptr; + for (const Edge* e : arg.enter->out_edges()) { + // Ignore control-edges to the sink node. 
These are allowed by the + // graph invariants, although probably they should have been stripped + // off earlier. + if (e->IsControlEdge() && e->dst()->IsSink()) { + continue; + } + if (enter_merge != nullptr) { + return errors::Internal("Enter node for loop-varying argument ", + FormatNodeForError(*arg.enter), + " has multiple successors: ", + FormatNodeForError(*enter_merge->dst()), + " and ", FormatNodeForError(*e->dst())); + } + enter_merge = e; + } + if (enter_merge == nullptr) { + return errors::Internal("Enter node for loop-varying argument ", + FormatNodeForError(*arg.enter), + " has zero successors"); + } + arg.merge = enter_merge->dst(); + if (!IsMerge(arg.merge)) { + return errors::InvalidArgument( + "Successor of Enter node for loop-varying argument ", + FormatNodeForError(*arg.merge), + " is not a Merge node; got: ", arg.merge->type_string()); + } + + // Find the NextIteration from the merge. There should be two inputs to + // the Merge and the NextIteration should be the other input. + if (arg.merge->input_types().size() != 2) { + return errors::InvalidArgument( + "Unexpected number of inputs to Merge node for loop-varying " + "argument ", + FormatNodeForError(*arg.merge), "; expected 2, got ", + arg.merge->input_types().size()); + } + TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), + &arg.next_iteration)); + if (!IsNextIteration(arg.next_iteration)) { + return errors::InvalidArgument( + "Expected NextIteration node as input to Merge node; got node ", + FormatNodeForError(*arg.next_iteration), " with kind ", + arg.next_iteration->type_string()); + } + + // Find the Switch successor of the Merge. There should be exactly one + // Switch node that is a successor of both the Merge and the LoopCond. 
+ for (const Edge* edge : arg.merge->out_edges()) { + if (edge->dst_input() == 0 && IsSwitch(edge->dst()) && + switches.find(edge->dst()) != switches.end()) { + if (arg.switch_node != nullptr) { + return errors::InvalidArgument("Duplicate Switch successors to ", + FormatNodeForError(*arg.merge)); + } + arg.switch_node = edge->dst(); + } + } + if (arg.switch_node == nullptr) { + return errors::InvalidArgument("Missing Switch successor to ", + FormatNodeForError(*arg.merge)); + } + + // Update the device on the Identity outputs of the switch to match their + // target. These Identity outputs do not + + // Loop over the switch node's output to: + // - Find the Exit successor. + // - Set the sharding on all Identity outputs of the switch. These + // identity nodes are values used by the loop body or condition. + // The Identity node may have the wrong device so copy the device from + // one of its outputs instead. + std::deque possible_exit; + for (const Edge* edge : arg.switch_node->out_edges()) { + if (edge->src_output() == 0) { + possible_exit.push_back(edge); + } + if (IsIdentity(edge->dst())) { + TF_RETURN_IF_ERROR( + SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); + } + } + // TODO(b/67425339): Allow general graph between switch and exit. + while (!possible_exit.empty()) { + const Edge* edge = possible_exit.front(); + possible_exit.pop_front(); + if (IsExit(edge->dst())) { + if (arg.exit != nullptr) { + return errors::InvalidArgument( + "Duplicate Exit successors to ", + FormatNodeForError(*arg.switch_node)); + } + arg.exit = edge->dst(); + } else { + if (!IsIdentity(edge->dst())) { + return errors::Unimplemented("General graph between switch (", + FormatNodeForError(*arg.switch_node), + ") and exit node of frame ", + frame->name, " not supported yet."); + } + for (const Edge* out : edge->dst()->out_edges()) { + possible_exit.push_back(out); + } + } + } + } + } + + // Builds the condition and body functions. 
+ std::unique_ptr cond_graph; + TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); + DataTypeVector arg_types; + std::unique_ptr body_graph; + TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); + + VLOG(2) << "Frame " << frame->name << " condition: " + << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) + << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); + + static std::atomic sequence_num(0LL); + int64 id = ++sequence_num; + NameAttrList cond_name; + cond_name.set_name(strings::StrCat("_functionalize_cond_", id)); + NameAttrList body_name; + body_name.set_name(strings::StrCat("_functionalize_body_", id)); + FunctionDef cond_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef)); + FunctionDef body_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef)); + + TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); + TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + if (lookup_library) { + // Copy missing FunctionDefs from lookup_library to library to make library + // self-contained. + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(cond_fdef, lookup_library, library)); + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(body_fdef, lookup_library, library)); + } + + // Builds a While operator. 
+ NodeDef while_def; + NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); + builder.Attr("T", arg_types); + builder.Attr("cond", cond_name); + builder.Attr("body", body_name); + std::vector inputs; + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + inputs.push_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), arg_types[i])); + } + } + builder.Input(inputs); + TF_RETURN_IF_ERROR(builder.Finalize(&while_def)); + TF_ASSIGN_OR_RETURN(Node * while_node, AddNodeDefToGraph(while_def, graph)); + + // Copies edges to the Enter nodes and from the Exit nodes onto the While. + for (int i = 0; i < frame->args.size(); ++i) { + const Arg& arg = frame->args[i]; + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + graph->AddControlEdge(in_edge->src(), while_node); + } else { + graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i); + } + + if (!arg.is_loop_invariant) { + // Add output edges if the output of the loop is consumed. + if (arg.exit != nullptr) { + std::vector edges(arg.exit->out_edges().begin(), + arg.exit->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + + if (dst_input == Graph::kControlSlot) { + graph->AddControlEdge(while_node, dst); + } else { + graph->AddEdge(while_node, i, dst, dst_input); + } + } + } + } + } + + // Remove the old nodes from the graph, and add the while node to the parent + // frame. 
+ for (Node* node : frame->nodes) { + graph->RemoveNode(node); + } + frame->nodes.clear(); + frame->parent->nodes.insert(while_node); + + VLOG(2) << "Frame " << frame->name << " after: " + << dump_graph::DumpGraphToFile("functionalize_after", *graph, + library); + + return Status::OK(); +} +} // namespace + +Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library) { + // Note: BuildControlFlowInfo() requires that the graph's source node is + // connected to all source nodes in the graph. Many graphs violate this + // invariant. + std::vector cf_info; + std::vector unreachable_nodes; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes)); + if (!unreachable_nodes.empty()) { + return errors::InvalidArgument( + "The following nodes are unreachable from the source in the graph: ", + errors::FormatNodeNamesForError(unreachable_nodes)); + } + + // Builds Frames, indexed by name. + std::unordered_map frames; + for (Node* node : graph->op_nodes()) { + const ControlFlowInfo& cf = cf_info[node->id()]; + + VLOG(2) << "node: " << node->name() << " (" << node->id() + << ") frame_name: " << cf.frame_name + << " frame: " << (cf.frame ? cf.frame->name() : "---") + << " parent_frame: " + << (cf.parent_frame ? cf.parent_frame->name() : "---"); + TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr); + + Frame& frame = frames[cf.frame_name]; + Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; + if (frame.parent == nullptr) { + frame.parent = parent; + frame.name = cf.frame_name; + ++parent->num_children; + } + + if (IsEnter(node)) { + Arg arg; + arg.enter = node; + TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant", + &arg.is_loop_invariant)); + frame.args.push_back(arg); + } else if (IsLoopCond(node)) { + frame.loop_cond = node; + } + frame.nodes.insert(node); + } + + // Adds frames with no children (i.e., the innermost frames) to a worklist. 
+ std::deque worklist; + for (auto& frame : frames) { + if (frame.second.num_children == 0) { + worklist.push_back(&frame.second); + } + } + + // Eliminate loops from innermost to outermost. + while (!worklist.empty()) { + Frame* frame = worklist.front(); + worklist.pop_front(); + if (frame->parent == frame) { + // Skip the root frame. + continue; + } + + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); + + // If the parent has no remaining children, add it to the worklist. + --frame->parent->num_children; + if (frame->parent->num_children == 0) { + worklist.push_back(frame->parent); + } + } + + // There should be no cycle at this point, since while loops have been removed + // from graph. + // Check that the newly added XlaWhile nodes don't feed into themselves. + for (const Node* node : graph->op_nodes()) { + if (node->def().op() == "XlaWhile") { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckNodeNotInCycle(node, graph->num_node_ids()), + "Functionalizing loop failed."); + } + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_while.h b/tensorflow/compiler/tf2xla/functionalize_while.h new file mode 100644 index 0000000000000000000000000000000000000000..a708c6e4ec4e13527b4ee2d6c435dddee0a2b4e2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_while.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Transformation that converts tf.while_loop() loops into functional While +// operators, suitable for XLA compilation. If lookup_library is provided, use +// it to make the library for control flow self-contained. +Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, FunctionLibraryDefinition* library); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_ diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index e4fdf0a6186eb69a2e3413838c91616b992ef2d6..1ed1fb3b021b27be00086b2e71cc9309e3d76049 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -57,7 +57,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, std::vector compile_time_constant_flags(expressions.size()); TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(*graph, &compile_time_constant_flags)); + BackwardsConstAnalysis(*graph, &compile_time_constant_flags, + /*compile_time_const_nodes=*/nullptr)); args->resize(expressions.size()); for (int i = 0; i < args->size(); ++i) { @@ -145,6 +146,7 @@ Status GraphCompiler::Compile() { } OpKernelContext op_context(¶ms, n->num_outputs()); + VLOG(3) << "Translating " << params.op_kernel->name(); if (IsFunctional(n)) { TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context)); } else { diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index b1366e9e31e28406c5bf1a808b9c5670558ed9c7..4c776fb1781e4d0b0d1fa5f313536eb42d6856bb 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ 
b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -22,6 +22,7 @@ tf_kernel_library( "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", + "broadcast_to_op.cc", "bucketize_op.cc", "cast_op.cc", "categorical_op.cc", @@ -100,6 +101,12 @@ tf_kernel_library( "unary_ops.cc", "unpack_op.cc", "variable_ops.cc", + "xla_broadcast_helper_op.cc", + "xla_conv_op.cc", + "xla_dot_op.cc", + "xla_pad_op.cc", + "xla_reduce_op.cc", + "xla_select_and_scatter_op.cc", ], hdrs = [ "index_ops.h", @@ -108,6 +115,9 @@ tf_kernel_library( deps = [ ":if_op", ":while_op", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 48f2a005ab16651fe29d0f6f9d881f95693da461..edced6bc0e57cfc2b1c62f1e4a010dd316f7d092 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -23,7 +23,7 @@ namespace { void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, DataType input_dtype, const TensorShape& input_tensor_shape, - gtl::ArraySlice block_shape, + absl::Span block_shape, const xla::Literal& crops) { const int input_rank = input_tensor_shape.dims(); const gtl::InlinedVector input_shape = @@ -34,7 +34,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, ctx, input_rank >= 1 + block_rank, errors::InvalidArgument("input rank should be >= ", 1 + block_rank, " instead of ", input_rank)); - gtl::ArraySlice remainder_shape(input_shape); + absl::Span remainder_shape(input_shape); remainder_shape.remove_prefix(1 + block_rank); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index 
ba3b1c9dab79a387c48e8e25e4804917f328f8a0..2e383b1473590403823863f89264e5381d8e8806 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -16,6 +16,7 @@ limitations under the License. // XLA-specific Ops for broadcasting used in gradient // code. +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -51,8 +52,8 @@ class BCastArgsOp : public XlaOpKernel { BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( - "Incompatible shapes: [", str_util::Join(shapes[0], ","), - "] vs. [", str_util::Join(shapes[1], ","), "]")); + "Incompatible shapes: [", absl::StrJoin(shapes[0], ","), + "] vs. [", absl::StrJoin(shapes[1], ","), "]")); const int64 len = bcast.output_shape().size(); Tensor output(DT_INT32, TensorShape({len})); @@ -105,8 +106,8 @@ class BCastGradArgsOp : public XlaOpKernel { BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( - "Incompatible shapes: [", str_util::Join(shapes[0], ","), - "] vs. [", str_util::Join(shapes[1], ","), "]")); + "Incompatible shapes: [", absl::StrJoin(shapes[0], ","), + "] vs. [", absl::StrJoin(shapes[1], ","), "]")); Output(ctx, 0, bcast.grad_x_reduce_idx()); Output(ctx, 1, bcast.grad_y_reduce_idx()); } diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 2c328102e0bd84709707f102272691b6aec9a577..df17da4c1ca07053cf63757f1acf2b1a3735e705 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -30,21 +30,21 @@ namespace { // A subclass of a XlaBinaryOp must build the computation that // describes the (tensor,tensor)->tensor function to apply to each element of // the input. 
-#define XLA_MAKE_BINARY(NAME, HLO) \ - class NAME##Op : public XlaBinaryOp { \ - public: \ - explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ - xla::XlaOp Computation( \ - XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \ - const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, \ - const gtl::ArraySlice& rhs_shape, \ - const BCast& broadcast_helper, \ - const std::vector& extend_dimensions) override { \ - xla::XlaBuilder* b = ctx->builder(); \ - (void)b; \ - return HLO; \ - } \ - }; \ +#define XLA_MAKE_BINARY(NAME, HLO) \ + class NAME##Op : public XlaBinaryOp { \ + public: \ + explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ + xla::XlaOp Computation( \ + XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \ + const absl::Span& lhs_shape, const xla::XlaOp& rhs, \ + const absl::Span& rhs_shape, \ + const BCast& broadcast_helper, \ + const std::vector& extend_dimensions) override { \ + xla::XlaBuilder* b = ctx->builder(); \ + (void)b; \ + return HLO; \ + } \ + }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op) XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions)); diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4bd7c74dca2a7cbb51f2a329ac575d635f314516 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { +namespace { + +class BroadcastToOp : public XlaOpKernel { + public: + explicit BroadcastToOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + TensorShape output_shape; + OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); + + OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(), + errors::InvalidArgument( + "Input rank (", input_shape.dims(), + ") must be less than or equal to the output rank (", + output_shape.dims(), ")")); + + auto input_dims = input_shape.dim_sizes(); + auto output_dims = output_shape.dim_sizes(); + + // Broadcasting is done right-to-left on right-aligned dimensions; reverse + // the two vectors so elements to be broadcast are aligned. 
+ absl::c_reverse(input_dims); + absl::c_reverse(output_dims); + + std::vector broadcast_dims; + std::vector broadcast_shape; + for (int i = 0; i < output_shape.dims(); ++i) { + if (i < input_shape.dims()) { + OP_REQUIRES( + context, + (output_dims[i] == 0 && input_dims[i] == 0) || + (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0), + errors::InvalidArgument("invalid shape to broadcast from ", + input_shape.DebugString(), " to ", + output_shape.DebugString())); + + broadcast_dims.push_back(broadcast_shape.size()); + if (output_dims[i] == input_dims[i] || input_dims[i] == 1) { + broadcast_shape.push_back(output_dims[i]); + } + if (output_dims[i] != input_dims[i]) { + // Add dimensions [I, O/I], which we will later flatten to just + // [O]. We must do this in two phases since XLA broadcasting does not + // support tiling. + broadcast_shape.push_back(input_dims[i]); + broadcast_shape.push_back(output_dims[i] / input_dims[i]); + } + } else { + broadcast_shape.push_back(output_dims[i]); + } + } + absl::c_reverse(broadcast_dims); + int broadcast_shape_size = broadcast_shape.size(); + for (int64& broadcast_dim : broadcast_dims) { + broadcast_dim = broadcast_shape_size - broadcast_dim - 1; + } + absl::c_reverse(broadcast_shape); + xla::XlaOp output = xla::Reshape( + xla::BroadcastInDim(context->Input(0), + xla::ShapeUtil::MakeShape( + context->input_xla_type(0), broadcast_shape), + broadcast_dims), + output_shape.dim_sizes()); + context->SetOutput(0, output); + } +}; + +REGISTER_XLA_OP(Name("BroadcastTo").CompileTimeConstInput("shape"), + BroadcastToOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 5da7972397b32fb4a2f216913e065c04131a3773..674720e22fbf9d995e74c7dbd0ef7d7765941867 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -120,45 +120,30 @@ xla::XlaOp 
CreateExpandedFilterMask(const TensorShape& filter_shape, {expanded_filter_shape.dims() - 2}); } -// Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding -// zeros for the cross-depth filters. Used to build a depthwise convolution. -xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape, - DataType dtype, - const xla::XlaOp& filter, - xla::XlaBuilder* builder) { - int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); - int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); +// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to +// build a depthwise convolution. +xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape, + const xla::XlaOp& filter) { + int64 input_feature_dim = filter_shape.dims() - 2; + int64 output_feature_dim = filter_shape.dims() - 1; + int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim); + int64 input_feature = filter_shape.dim_size(input_feature_dim); // Create a [H, W, ..., 1, N*M] reshape of the filter. - TensorShape implicit_broadcast_filter_shape = expanded_filter_shape; - implicit_broadcast_filter_shape.set_dim( - implicit_broadcast_filter_shape.dims() - 2, 1); - implicit_broadcast_filter_shape.set_dim( - implicit_broadcast_filter_shape.dims() - 1, - depthwise_multiplier * input_feature); - auto implicit_broadcast_filter = - xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); - - // Broadcast the filter to [H, W, ..., M, M*N]. - auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); - auto expanded_filter = xla::Add(implicit_broadcast_filter, expanded_zero); - - // If the filter mask is set, choose the broadcasted filter, othwerwise, - // choose zero. 
- return xla::Select(CreateExpandedFilterMask(filter_shape, builder), - expanded_filter, expanded_zero); + TensorShape implicit_broadcast_filter_shape = filter_shape; + implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1); + implicit_broadcast_filter_shape.set_dim(output_feature_dim, + depthwise_multiplier * input_feature); + return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); } -// Inverse of ExpandFilterForDepthwiseConvolution. +// Reduces the results of the convolution with an expanded filter to the +// non-expanded filter. xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype, const xla::XlaOp& filter_backprop, xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); auto masked_expanded_filter = xla::Select( CreateExpandedFilterMask(filter_shape, builder), filter_backprop, CreateExpandedZero(filter_shape, dtype, builder)); @@ -168,8 +153,7 @@ xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, // ExpandedZero guarantees that only one element is non zero, so there // cannot be accumulated precision error. 
xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), - *ctx->GetOrCreateAdd(dtype), - {expanded_filter_shape.dims() - 2}), + *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}), filter_shape.dim_sizes()); } @@ -245,15 +229,9 @@ class ConvOp : public XlaOpKernel { "input and filter must have the same depth: ", in_depth, " vs ", input_shape.dim_size(feature_dim))); - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp filter = ctx->Input(1); - TensorShape expanded_filter_shape = filter_shape; if (depthwise_) { - filter = ExpandFilterForDepthwiseConvolution( - filter_shape, ctx->input_type(0), filter, b); - expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter); } xla::ConvolutionDimensionNumbers dims; @@ -280,14 +258,15 @@ class ConvOp : public XlaOpKernel { int64 unused_output_size; OP_REQUIRES_OK( ctx, GetWindowedOutputSizeVerboseV2( - input_shape.dim_size(dim), expanded_filter_shape.dim_size(i), + input_shape.dim_size(dim), filter_shape.dim_size(i), rhs_dilation[i], window_strides[i], padding_, &unused_output_size, &padding[i].first, &padding[i].second)); } - xla::XlaOp conv = - xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, - lhs_dilation, rhs_dilation, dims); + xla::XlaOp conv = xla::ConvGeneralDilated( + ctx->Input(0), filter, window_strides, padding, lhs_dilation, + rhs_dilation, dims, + /*feature_group_count=*/depthwise_ ? in_depth : 1); ctx->SetOutput(0, conv); } @@ -388,7 +367,6 @@ class ConvBackpropInputOp : public XlaOpKernel { expanded_filter_shape, out_backprop_shape, dilations_, strides_, padding_, data_format_, &dims)); - xla::XlaBuilder* b = ctx->builder(); auto filter = ctx->Input(1); auto out_backprop = ctx->Input(2); @@ -425,12 +403,6 @@ class ConvBackpropInputOp : public XlaOpKernel { rhs_dilation[i] = dilations_[dim]; } - // If this is a depthwise convolution, expand the filter. 
- if (depthwise_) { - filter = ExpandFilterForDepthwiseConvolution( - filter_shape, ctx->input_type(1), filter, b); - } - // Mirror the filter in the spatial dimensions. xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); @@ -438,7 +410,11 @@ class ConvBackpropInputOp : public XlaOpKernel { // = gradients (with padding and dilation) mirrored_weights xla::XlaOp in_backprop = xla::ConvGeneralDilated( out_backprop, mirrored_weights, /*window_strides=*/ones, padding, - lhs_dilation, rhs_dilation, dnums); + lhs_dilation, rhs_dilation, dnums, + /*feature_group_count=*/ + depthwise_ ? out_backprop_shape.dim_size(feature_dim) / + filter_shape.dim_size(num_spatial_dims_ + 1) + : 1); ctx->SetOutput(0, in_backprop); } diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index a5b870f8dbf70bcee331992345d63fd5d986bdca..6653944a911588b7bc88d67b8cdd2c17850530f0 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -57,8 +57,8 @@ class XlaBinaryOp : public XlaOpKernel { // in the XLA documentation. virtual xla::XlaOp Computation( XlaOpKernelContext* ctx, const xla::XlaOp& lhs, - const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, - const gtl::ArraySlice& rhs_shape, const BCast& broadcast_helper, + const absl::Span& lhs_shape, const xla::XlaOp& rhs, + const absl::Span& rhs_shape, const BCast& broadcast_helper, const std::vector& extend_dimensions) = 0; void Compile(XlaOpKernelContext* ctx) override; diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index ed44ad218b6dc073583ec339da082b6881ad672d..49c12fc232092873b69961644a059abc6035f64f 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -29,7 +29,7 @@ namespace { // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. 
xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size, - gtl::ArraySlice other_dims, + absl::Span other_dims, xla::PrimitiveType element_type) { xla::XlaBuilder* builder = input.builder(); // Create two matrices that have the following forms, and compare them: @@ -177,8 +177,8 @@ class MatrixDiagOp : public XlaOpKernel { int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); - tensorflow::gtl::ArraySlice other_dims(dims); - other_dims.pop_back(); + absl::Span other_dims(dims); + other_dims.remove_suffix(1); xla::XlaOp input = ctx->Input(0); xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims, diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 35de96e0aab847fa39ef26d5f3052c392062fd7d..44140304fdf5cdf60d8ad8b85c532fcadff8ba86 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -95,11 +95,11 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, // operand = s32[3,3] parameter(0) // indices = s32[2] parameter(1) // gather = s32[3,2] gather(operand, indices), - // output_window_dims={0}, - // elided_window_dims={1}, - // gather_dims_to_operand_dims={1}, + // offset_dims={0}, + // collapsed_slice_dims={1}, + // start_index_map={1}, // index_vector_dim=1, - // window_bounds={3, 1} + // slice_sizes={3, 1} // // // Example of an N-D gather pulling out slices of shape [1,1,2] out of a @@ -108,42 +108,42 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, // operand = s32[3,3,2] parameter(0) // indices = s32[2,2] parameter(1) // gather = s32[2,2] gather(operand, indices), - // output_window_dims={1}, - // elided_window_dims={0,1}, - // gather_dims_to_operand_dims={0,1}, + // offset_dims={1}, + // collapsed_slice_dims={0,1}, + // start_index_map={0,1}, // index_vector_dim=0, - // window_bounds={1,1,2} + // slice_sizes={1,1,2} 
xla::GatherDimensionNumbers dim_numbers; - std::vector window_bounds; - window_bounds.reserve(input_shape.dims()); + std::vector slice_sizes; + slice_sizes.reserve(input_shape.dims()); for (int64 i = 0; i < input_shape.dims(); i++) { int64 window_bound; if (axis <= i && i < (axis + num_index_dims)) { - dim_numbers.add_elided_window_dims(i); + dim_numbers.add_collapsed_slice_dims(i); window_bound = 1; } else { window_bound = input_shape.dim_size(i); } - window_bounds.push_back(window_bound); + slice_sizes.push_back(window_bound); if (i < axis) { - dim_numbers.add_output_window_dims(i); + dim_numbers.add_offset_dims(i); } else if (i >= (axis + num_index_dims)) { int64 indices_rank = indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims(); - dim_numbers.add_output_window_dims(i + indices_rank - num_index_dims); + dim_numbers.add_offset_dims(i + indices_rank - num_index_dims); } } dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims()); for (int64 i = axis; i < axis + num_index_dims; i++) { - dim_numbers.add_gather_dims_to_operand_dims(i); + dim_numbers.add_start_index_map(i); } - *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds); + *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index e72200bfbcff20c55ac03030f1afc4bacaabf7ce..19dd38c46ef154ea74bcbb6721dd04924702efcc 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -25,7 +25,10 @@ class IdentityOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { for (int i = 0; i < ctx->num_inputs(); ++i) { - ctx->SetOutput(i, ctx->Input(i)); + // Forwards using the underlying op_kernel_context so both tensor and + // resource values are forwarded correctly. 
+ ctx->op_kernel_context()->set_output(i, + ctx->op_kernel_context()->input(i)); } } @@ -35,9 +38,10 @@ class IdentityOp : public XlaOpKernel { // XLA_* devices also register a "real" Identity operator so we suppress the // dummy operator using CompilationOnly(). -REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp); - -REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp); +REGISTER_XLA_OP(Name("Identity").AllowResourceTypes().CompilationOnly(), + IdentityOp); +REGISTER_XLA_OP(Name("IdentityN").AllowResourceTypes().CompilationOnly(), + IdentityOp); REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); REGISTER_XLA_OP(Name("StopGradient"), IdentityOp); diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 6a7eb8d90c45ab119096eaa259e05c6ca768c5aa..6e1dbf5472f0b1eb0abcbe29c553ae926ecf2d8a 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -200,21 +200,10 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } } - bool resource_variable_seen = false; - for (int i = 0; i < ctx->num_inputs(); ++i) { - if (ctx->input_type(i) == DT_RESOURCE) { - resource_variable_seen = true; - } else { - OP_REQUIRES( - ctx, !resource_variable_seen, - errors::FailedPrecondition( - "Resource variables and regular inputs cannot be interleaved.")); - } - } - - xla::XlaOp outputs = xla::Conditional( - ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation, - xla::Tuple(b, inputs), *else_result.computation); + auto input_tuple = xla::Tuple(b, inputs); + xla::XlaOp outputs = + xla::Conditional(ctx->Input(0), input_tuple, *then_result.computation, + input_tuple, *else_result.computation); // Sets non-variable outputs. 
for (int i = 0; i < output_types_.size(); ++i) { xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 8d75624e74028ea083c3facc4f9578ec14c50e6d..d9a0257b70bcf302dea77db2e9f7fa7b4543e038 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -32,13 +32,13 @@ namespace { // // 1. S := (N - 1) / gcd(N-1, R-1) // 2. k := (R - 1) / gcd(N-1, R-1) -// 3. Convolution(kxk, stride=S, lhs_dilation=k, padding=k-1) +// 3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1) // // For example, to Scale from 7x7 -> 15x15: // // 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3 // 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7 -// 3. Convolution(7x7, stride=3, lhs_dilation=3, padding=2) +// 3. Convolution(15x15, stride=3, lhs_dilation=7, padding=2) // // // The 7x7 -> 15x15 case is much too large to write out in full as an @@ -65,6 +65,8 @@ namespace { // 1/9 * 3 6 9 6 3 // 2 4 6 4 2 // 1 2 3 2 1 +// Note that the convolution kernel matrix is separable and thus we can instead +// use 2 consecutive 1D kernel of the dimension 2k-1, along each axis. // Computes the size of the convolutional kernel and stride to use when resizing // from in_size to out_size. @@ -76,7 +78,8 @@ struct ResizeConvolutionDims { std::vector stride; }; ResizeConvolutionDims ComputeResizeConvolutionParameters( - gtl::ArraySlice in_size, gtl::ArraySlice out_size) { + absl::Span in_size, absl::Span out_size, + bool align_corners) { CHECK_EQ(in_size.size(), out_size.size()); int num_spatial_dims = in_size.size(); ResizeConvolutionDims dims; @@ -92,15 +95,32 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( // entry before resizing. 
dims.stride[i] = dims.kernel_size[i] = 1; } else { - int64 gcd = MathUtil::GCD(static_cast(in_size[i] - 1), - static_cast(out_size[i] - 1)); - dims.stride[i] = (in_size[i] - 1) / gcd; - dims.kernel_size[i] = (out_size[i] - 1) / gcd; + // The scaling factor changes depending on the alignment of corners. + const int64 in_size_factor = align_corners ? in_size[i] - 1 : in_size[i]; + const int64 out_size_factor = + align_corners ? out_size[i] - 1 : out_size[i]; + + int64 gcd = MathUtil::GCD(static_cast(in_size_factor), + static_cast(out_size_factor)); + dims.stride[i] = in_size_factor / gcd; + dims.kernel_size[i] = out_size_factor / gcd; } } return dims; } +// The upper padding of the input needed by ConvGeneralDilated calls is +// determined by solving two related relationships (assuming rhs_dilation == 0): +// 1. dilated_input_dim = lower_padding + upper_padding +// + lhs_dilation * (in_size - 1) + 1 +// 2. dilated_input_dim = (2 * dims.kernel-size - 1) +// + dims.stride * (out_size - 1) +int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, + int64 stride) { + return (2 * kernel_size - 1) + (out_size - 1) * stride - (kernel_size - 1) - + 1 - (kernel_size * (in_size - 1)); +} + // Form a 2D convolution kernel like: // 1 2 3 2 1 // 2 4 6 4 2 @@ -127,7 +147,7 @@ std::vector Make1DKernel(int64 n) { const int64 kMax2DKernelSize = 16; xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, - gtl::ArraySlice kernel_size, + absl::Span kernel_size, int64 channels) { xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); @@ -145,7 +165,7 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, } xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, - gtl::ArraySlice kernel_size, + absl::Span kernel_size, int64 channels, int64 dim) { xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); @@ -171,7 +191,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, const int num_spatial_dims, 
std::vector in_size, std::vector out_size, - const int64 channels) { + const int64 channels, + const bool align_corners) { // Picture for a 1x3 to 1x4 resize: // stride = 2, kernel size = 3 // Input: @@ -196,27 +217,82 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); ResizeConvolutionDims dims = - ComputeResizeConvolutionParameters(in_size, out_size); + ComputeResizeConvolutionParameters(in_size, out_size, align_corners); xla::XlaOp output; - // Split convolutions into independent dimensions if they wmuld be a very + + // Concatenation and padding below currently assumes num_spatial_dims is 2 to + // prevent needless code complexity. + CHECK_EQ(num_spatial_dims, 2) + << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently."; + std::vector upper_padding(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + upper_padding[i] = dims.kernel_size[i] - 1; + } + xla::XlaOp input_data = input; + + if (!align_corners) { + // When Tensorflow does not align_corners, the resize indexing can access + // beyond the upper bound and is instead clamped to prevent out of bounds + // reads. This is conceptually the same as extending the edges of the input. + // We emulate this by copying the last row/column of the input. + // Calculate what padding would be needed then determine how far to extend + // the border before lhs dilation. 
+ std::vector num_extended(num_spatial_dims); + upper_padding[0] = CalculateUpperPadding( + in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]); + upper_padding[1] = CalculateUpperPadding( + in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]); + num_extended[0] = upper_padding[0] / (dims.kernel_size[0]); + num_extended[1] = upper_padding[1] / (dims.kernel_size[1]); + + if (num_extended[0] > 0) { + auto slice = + xla::Slice(input_data, {0, in_size[0] - 1, 0, 0}, + {1, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); + for (int i = 0; i < num_extended[0]; i++) { + input_data = xla::ConcatInDim(builder, {input_data, slice}, 1); + } + } + + if (num_extended[1] > 0) { + auto slice = + xla::Slice(input_data, {0, 0, in_size[1] - 1, 0}, + {1, in_size[0] + num_extended[0], in_size[1], channels}, + {1, 1, 1, 1}); + for (int i = 0; i < num_extended[1]; i++) { + input_data = xla::ConcatInDim(builder, {input_data, slice}, 2); + } + } + + // Setting in_size to (in_size + num_extended) due to the above Slice and + // ConcatInDim. Recalculate needed padding after the above Slice/Concat. + upper_padding[0] = + CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0], + dims.kernel_size[0], dims.stride[0]); + upper_padding[1] = + CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1], + dims.kernel_size[1], dims.stride[1]); + } + + // Split convolutions into independent dimensions if they would be a very // large kernel. 
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { xla::XlaOp kernel = MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - output = xla::ConvGeneralDilated( - input, kernel, dims.stride, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + output = + xla::ConvGeneralDilated(input_data, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, upper_padding[0]}, + {dims.kernel_size[1] - 1, upper_padding[1]}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dimension_numbers); } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); output = xla::ConvGeneralDilated( - input, kernel0, {dims.stride[0], 1}, + input_data, kernel0, {dims.stride[0], 1}, /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}}, /*lhs_dilation=*/{dims.kernel_size[0], 1}, /*rhs_dilation=*/{1, 1}, dimension_numbers); xla::XlaOp kernel1 = @@ -224,7 +300,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ - {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}}, /*lhs_dilation=*/{1, dims.kernel_size[1]}, /*rhs_dilation=*/{1, 1}, dimension_numbers); } @@ -245,9 +321,10 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, const int num_spatial_dims, std::vector in_size, std::vector grad_size, - const int64 channels) { + const int64 channels, + const bool align_corners) { ResizeConvolutionDims dims = - ComputeResizeConvolutionParameters(in_size, grad_size); + ComputeResizeConvolutionParameters(in_size, grad_size, align_corners); // To form the backward convolution, we keep the 
kernel unchanged (it is // already symmetric) and swap the roles of strides and LHS dilation. @@ -341,10 +418,6 @@ class ResizeBilinearOp : public XlaOpKernel { public: explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); - OP_REQUIRES( - ctx, align_corners_ == true, - errors::Unimplemented( - "ResizeBilinear with align_corners=False is not yet implemented")); } void Compile(XlaOpKernelContext* ctx) override { @@ -377,20 +450,19 @@ class ResizeBilinearOp : public XlaOpKernel { // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in // dimension i. - std::vector slice_size = in_size; bool slice_input = false; for (int i = 0; i < num_spatial_dims; ++i) { if (in_size[i] > 1 && out_size[i] == 1) { // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first // entry before resizing. slice_input = true; - slice_size[i] = 1; + in_size[i] = 1; } } if (slice_input) { - input = xla::Slice(input, {0, 0, 0, 0}, - {batch, slice_size[0], slice_size[1], channels}, - {1, 1, 1, 1}); + input = + xla::Slice(input, {0, 0, 0, 0}, + {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); } // Output is always type float. @@ -406,6 +478,9 @@ class ResizeBilinearOp : public XlaOpKernel { // operations along different dimensions. // Given sufficient numerical stability and a cxd is same as resizing axb -> exf -> cxd. + // This does not work in the case of align_corners_=false because of special + // padding requirements that cause multiple resizes to be very different + // from a single resize. // // This makes the convolutions kernels smaller and the operation faster. 
xla::XlaOp output = input; @@ -415,21 +490,24 @@ class ResizeBilinearOp : public XlaOpKernel { (static_cast(out_size[0]) - 1) / ((in_size[0] - 1) * 2), (static_cast(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && - k[0] > 1 && k[1] > 1) { + k[0] > 1 && k[1] > 1 && align_corners_) { std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; - output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, next_out_size, channels); + output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, + in_size, next_out_size, + channels, align_corners_); input = output; in_size = next_out_size; } else { - output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, out_size, channels); + output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, + in_size, out_size, + channels, align_corners_); in_size = out_size; } } else { output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, out_size, channels); + in_size, out_size, channels, + align_corners_); in_size = out_size; } } @@ -509,17 +587,20 @@ class ResizeBilinearGradOp : public XlaOpKernel { std::vector next_grad_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, next_grad_size, channels); + b, grad, num_spatial_dims, in_size, next_grad_size, channels, + align_corners_); grad = output; in_size = next_grad_size; } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels); + b, grad, num_spatial_dims, in_size, grad_size, channels, + align_corners_); in_size = grad_size; } } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels); + b, grad, num_spatial_dims, in_size, grad_size, channels, + align_corners_); in_size = grad_size; } } diff --git 
a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index eedfc3c9140d7b1ccc1944611de98c1d49fbdaf2..2a42eeaf76ab3aa88ff3a93ef7eb7ab217964bb6 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -29,7 +29,14 @@ class MirrorPadOp : public XlaOpKernel { xla::StatusOr DoMirrorPad(const xla::XlaOp& t, const xla::Shape& original_shape, const xla::LiteralSlice& pad_literal, + const MirrorPadMode mode, xla::XlaBuilder* b) { + // The difference in the semantics of REFLECT and SYMMETRIC is that REFLECT + // will not mirror the border values while symmetric does. + // e.g. input is [1, 2, 3] and paddings is [0, 2], then the output is: + // - [1, 2, 3, 2, 1] in reflect mode + // - [1, 2, 3, 3, 2] in symmetric mode. + int64 excluded_edges = mode == MirrorPadMode::REFLECT ? 1 : 0; xla::XlaOp accum = t; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { @@ -39,9 +46,19 @@ class MirrorPadOp : public XlaOpKernel { TF_ASSIGN_OR_RETURN(int64 rhs_padding, pad_literal.GetIntegralAsS64({dimno, 1})); int64 dim_size = original_shape.dimensions(dimno); - auto lhs_pad = xla::SliceInDim(t_rev, dim_size - 1 - lhs_padding, - dim_size - 1, 1, dimno); - auto rhs_pad = xla::SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno); + + // Padding amounts on each side must be no more than the size of the + // original shape. 
+ TF_RET_CHECK(lhs_padding >= 0 && + lhs_padding <= dim_size - excluded_edges); + TF_RET_CHECK(rhs_padding >= 0 && + rhs_padding <= dim_size - excluded_edges); + + auto lhs_pad = + xla::SliceInDim(t_rev, dim_size - excluded_edges - lhs_padding, + dim_size - excluded_edges, 1, dimno); + auto rhs_pad = xla::SliceInDim(t_rev, excluded_edges, + excluded_edges + rhs_padding, 1, dimno); accum = xla::ConcatInDim(b, {lhs_pad, accum, rhs_pad}, dimno); } return accum; @@ -53,9 +70,10 @@ class MirrorPadOp : public XlaOpKernel { MirrorPadMode mode; OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode)); - OP_REQUIRES(ctx, mode == MirrorPadMode::REFLECT, - xla::Unimplemented( - "Only REFLECT MirrorPad mode is currently supported")); + OP_REQUIRES( + ctx, mode == MirrorPadMode::REFLECT || mode == MirrorPadMode::SYMMETRIC, + xla::Unimplemented("Unsupported MirrorPad mode. Only SYMMETRIC and " + "REFLECT modes are currently supported")); const int dims = input_shape.dims(); OP_REQUIRES( @@ -83,7 +101,7 @@ class MirrorPadOp : public XlaOpKernel { xla::StatusOr in0_shape = b->GetShape(in0); OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status()); xla::StatusOr accum_status = - DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, b); + DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, mode, b); OP_REQUIRES_OK(ctx, accum_status.status()); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index d4d180aff806f12875f0e43f111ee090f6607ef6..f6f158a73be42ea2602811ad64a2a2c655dab088 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -199,59 +199,6 @@ class MaxPool3DOp : public MaxPoolOp { }; REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); -// Divide each element of an image by the count of elements that contributed to -// that element during pooling. 
-static xla::XlaOp AvgPoolDivideByCount( - XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape, xla::Padding padding, - const std::vector& ksize, const std::vector& stride, - int num_spatial_dims, TensorFormat data_format) { - if (padding == xla::Padding::kValid) { - // In VALID padding, all windows have the same number of elements - // contributing to each average. Divide by the window size everywhere to - // get the average. - int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1, - [](int64 a, int64 b) { return a * b; }); - - auto divisor = - XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size); - return xla::Div(output, divisor); - } else { - // For SAME padding, the padding shouldn't be included in the - // counts. We use another ReduceWindow to find the right counts. - - // TODO(phawkins): use a less brute-force way to compute this. Only - // the boundary regions will have interesting values here. - - std::vector input_dim_sizes(num_spatial_dims); - std::vector window_dims(num_spatial_dims); - std::vector window_ksize(num_spatial_dims); - std::vector window_stride(num_spatial_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i); - input_dim_sizes[i] = input_shape.dim_size(dim); - window_dims[i] = dim; - window_ksize[i] = ksize[dim]; - window_stride[i] = stride[dim]; - } - - // Build a matrix of all 1s, with the same width/height as the input. - const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto ones = xla::Broadcast( - XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); - - // Perform a ReduceWindow with the same window size, strides, and padding - // to count the number of contributions to each result element. 
- auto reduce = xla::ReduceWindow( - ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, - xla::Padding::kSame); - auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); - - return xla::Div(output, counts, window_dims); - } -} - class AvgPoolOp : public PoolingOp { public: AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) @@ -463,78 +410,31 @@ class AvgPoolGradOp : public XlaOpKernel { errors::InvalidArgument("out_backprop must be ", num_dims(), "-dimensional")); - int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - int64 depth = out_backprop_shape.dim_size(depth_dim); - - // We can think of average-pooling as: - // * a convolution with a kernel consisting entirely of 1s, where the - // input feature and output feature are equal, and 0s everywhere else. - // * followed by dividing by the counts. - // - // This then gives us an algorithm to build the gradient: - // * divide out_backprop by the counts, followed by - // * Conv2DBackpropInput specialized for that kernel, which simplifies to - // a Pad and a ReduceWindow. - // - // For an explanation of backpropagation for convolution, see the comments - // in third_party/tensorflow/core/kernels/conv_grad_ops.h - - // TF filter shape is [ H, W, ..., inC, outC ] - std::vector filter_dims(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - filter_dims[i] = ksize_[dim]; - } - filter_dims[num_dims() - 2] = depth; - filter_dims[num_dims() - 1] = depth; - TensorShape filter_shape(filter_dims); - - // Reuse the logic from Conv2DBackpropInput to compute padding. 
- ConvBackpropDimensions dims; - OP_REQUIRES_OK( - ctx, ConvBackpropComputeDimensions( - type_string(), /*num_spatial_dims=*/num_spatial_dims_, - gradients_shape, filter_shape, out_backprop_shape, stride_, - padding_, data_format_, &dims)); - - // The input gradients are computed by a convolution of the output gradients - // and the filter, with some appropriate padding. See the comment at the top - // of conv_grad_ops.h for details. - xla::XlaBuilder* const b = ctx->builder(); auto out_backprop = ctx->Input(1); - auto dtype = input_type(1); + std::vector stride_int64s(stride_.begin(), stride_.end()); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; - - // Divide the out_backprop values by the counts for each spatial position. - std::vector stride_int64s(stride_.begin(), stride_.end()); - auto out_backprop_div = AvgPoolDivideByCount( - ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_, - stride_int64s, num_spatial_dims_, data_format_); - - // Pad the gradients in the spatial dimensions. We use the same padding - // as Conv2DBackpropInput. 
- xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - auto* padding = padding_config.mutable_dimensions(dim); - padding->set_edge_padding_low(dims.spatial_dims[i].pad_before); - padding->set_edge_padding_high(dims.spatial_dims[i].pad_after); - padding->set_interior_padding(dims.spatial_dims[i].stride - 1); - } - - auto zero = XlaHelpers::Zero(b, dtype); - auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config); - - // in_backprop = padded_gradients ones - std::vector ones(num_dims(), 1LL); - auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto in_backprop = xla::ReduceWindow( - XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), - XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), ksize_, - /* window_strides=*/ones, xla::Padding::kValid); - ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype)); + xla::PrimitiveType xla_reduction_type; + auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1)); + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type)); + auto converted_out_backprop = + xla::ConvertElementType(out_backprop, xla_reduction_type); + auto xla_data_format = + XlaTensorFormat(data_format_, gradients_shape.dims() - 2); + auto padding_values = + MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s, + xla_padding, xla_data_format); + auto in_backprop = + xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(), + ksize_, stride_int64s, padding_values, xla_data_format, + /*counts_include_padding=*/padding_ == VALID); + // Convert the pooling result back to the input type before returning it. 
+ xla::PrimitiveType xla_out_backprop_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1), + &xla_out_backprop_type)); + ctx->SetOutput(0, + xla::ConvertElementType(in_backprop, xla_out_backprop_type)); } protected: diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc index de9068a640dc03b141b6954eaa1629dd6c8c1f3a..7ea0afc1f53cbe4cfcc3f6121a4ecd55864c1b52 100644 --- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -23,15 +23,10 @@ namespace { class QROp : public XlaOpKernel { public: explicit QROp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - bool full_matrices; - OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices)); - OP_REQUIRES( - ctx, full_matrices, - errors::Unimplemented("full_matrices=False case of QR decomposition is " - "not implemented in TF/XLA")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_)); } void Compile(XlaOpKernelContext* ctx) override { - auto result = QRDecomposition(ctx->Input(0)); + auto result = QRDecomposition(ctx->Input(0), full_matrices_); if (!result.ok()) { ctx->SetStatus(result.status()); return; @@ -39,6 +34,11 @@ class QROp : public XlaOpKernel { ctx->SetOutput(0, result.ValueOrDie().q); ctx->SetOutput(1, result.ValueOrDie().r); } + + private: + // If true, compute full-sized q and r. If false, compute only the leading P + // columns of q. 
+ bool full_matrices_; }; REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 2da9340625db08b14b78340c471f096baf15689d..afd5986846705f66eb4c7ced9dbe2f4757f5af7f 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -155,7 +155,8 @@ class RandomShuffleOp : public XlaOpKernel { xla::XlaOp indices = xla::Iota(builder, xla::S32, n); // Swap the indices at i and swaps[i]. - auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + auto swap_body_fn = [&](xla::XlaOp i, + absl::Span loop_vars, xla::XlaBuilder* builder) -> xla::StatusOr> { auto swaps = loop_vars[0]; diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index b11a4ce36da9907ce8fe377c075023a4540797fa..8102faad28db71075fb8da269c55edbdb667193e 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -32,41 +32,30 @@ class ReduceWindowOp : public XlaOpKernel { explicit ReduceWindowOp(OpKernelConstruction* context) : XlaOpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("computation", &computation_)); - OP_REQUIRES_OK(context, - context->GetAttr("window_dimensions", &window_dimensions_)); - OP_REQUIRES_OK(context, - context->GetAttr("window_strides", &window_strides_)); - OP_REQUIRES_OK(context, context->GetAttr("padding_low", &padding_low_)); - OP_REQUIRES_OK(context, context->GetAttr("padding_high", &padding_high_)); } void Compile(XlaOpKernelContext* context) override { const TensorShape input_shape = context->InputShape(0); const DataType dtype = context->input_type(0); + std::vector window_dimensions; + std::vector window_strides; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dimensions", &window_dimensions)); + 
OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + const int rank = input_shape.dims(); - OP_REQUIRES(context, rank == window_dimensions_.size(), + OP_REQUIRES(context, rank == window_dimensions.size(), errors::InvalidArgument( "The size of window_dimensions must be equal to the input " "rank (", - window_dimensions_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == window_strides_.size(), + window_dimensions.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_strides.size(), errors::InvalidArgument( "The size of window_strides must be equal to the input " "rank (", - window_strides_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == padding_low_.size(), - errors::InvalidArgument( - "The size of padding_low must be equal to the input " - "rank (", - padding_low_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == padding_high_.size(), - errors::InvalidArgument( - "The size of padding_high must be equal to the input " - "rank (", - padding_high_.size(), " vs. ", rank, ")")); - - xla::XlaBuilder* builder = context->builder(); + window_strides.size(), " vs. ", rank, ")")); // Build the reducer function. XlaCompiler::Argument reducer_arg; @@ -78,6 +67,7 @@ class ReduceWindowOp : public XlaOpKernel { compile_options.use_tuple_arg = false; compile_options.resolve_compile_time_constants = false; compile_options.is_entry_computation = false; + compile_options.always_return_tuple = false; XlaCompiler::CompilationResult reducer; OP_REQUIRES_OK(context, context->compiler()->CompileFunction( compile_options, *computation_, @@ -86,51 +76,47 @@ class ReduceWindowOp : public XlaOpKernel { xla::Shape scalar_shape; OP_REQUIRES_OK(context, TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of ReduceWindow reducer. 
Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + + const TensorShape padding_shape = context->InputShape("padding"); OP_REQUIRES(context, - xla::ShapeUtil::Compatible( - reducer.xla_output_shape, - xla::ShapeUtil::MakeTupleShape({scalar_shape})), + TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, errors::InvalidArgument( - "Invalid output shape of ReduceWindow reducer. Expected ", - xla::ShapeUtil::HumanString(scalar_shape), " got ", - xla::ShapeUtil::HumanString(reducer.xla_output_shape))); - - // Wraps the reducer in a computation that unpacks the output tuple. - xla::XlaComputation wrapper; - { - std::unique_ptr cb = - builder->CreateSubBuilder("wrapper"); - auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x"); - auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y"); - auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y}); - xla::GetTupleElement(outputs, 0); - xla::StatusOr result = cb->Build(); - OP_REQUIRES_OK(context, result.status()); - wrapper = std::move(result.ValueOrDie()); - } - - std::vector> padding(rank); - for (int i = 0; i < rank; ++i) { - padding[i] = {padding_low_[i], padding_high_[i]}; + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get({i, 0}), + padding_literal.Get({i, 1})}; } xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( - context->Input(0), context->Input(1), wrapper, window_dimensions_, - window_strides_, padding); + context->Input(0), context->Input(1), *reducer.computation, + window_dimensions, window_strides, padding); context->SetOutput(0, output); } private: const NameAttrList* computation_; - std::vector 
window_dimensions_; - std::vector window_strides_; - std::vector padding_low_; - std::vector padding_high_; TF_DISALLOW_COPY_AND_ASSIGN(ReduceWindowOp); }; -REGISTER_XLA_OP(Name("XlaReduceWindow"), ReduceWindowOp); +REGISTER_XLA_OP(Name("XlaReduceWindow") + .CompileTimeConstInput("window_dimensions") + .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("padding"), + ReduceWindowOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index b52f0a0ab6290f2019bb58120be5c2364ec15bb6..598248563bb93146e6dea3016822d26b8bf368e7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific reduction Ops. +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -29,9 +30,6 @@ namespace tensorflow { XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type) : XlaOpKernel(ctx), reduction_type_(reduction_type) { - const DataType dt = BaseType(input_type(0)); - OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); - OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); OP_REQUIRES_OK( ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); @@ -58,20 +56,24 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { return; } + OP_REQUIRES(ctx, axes_tensor_shape.dims() <= 1, + errors::InvalidArgument( + "Expected scalar or vector as index argument, got ", + axes_tensor_shape.DebugString())); + // Evaluate the constant, reshaping to a 1-vector if it is a scalar. 
+ std::vector axes; xla::Literal axes_literal; - OP_REQUIRES_OK( - ctx, ctx->ConstantInputReshaped(1, {axes_tensor_shape.num_elements()}, - &axes_literal)); + OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector(1, &axes)); VLOG(1) << "data shape: " << data_shape.DebugString(); - VLOG(1) << "axes : " << axes_literal.ToString(); + VLOG(1) << "axes : " << absl::StrJoin(axes, ","); gtl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { - int32 index = axes_literal.Get({i}); + int64 index = axes[i]; OP_REQUIRES(ctx, !(index < -data_shape.dims() || index >= data_shape.dims()), errors::InvalidArgument("Invalid reduction dimension (", index, diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 121750a82a8c5cbe940068555ad273b7e0d22dfc..366ce42866e9f1375ee0ff6f4985c8f461fc0885 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -41,8 +41,8 @@ class ReshapeOp : public XlaOpKernel { sizes_shape.DebugString())); const int64 num_dims = sizes_shape.num_elements(); - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); + std::vector shape_input; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input)); // Compute the output shape. 
Determine product of specified // dimensions, and find the index of the unspecified one if there @@ -51,7 +51,7 @@ class ReshapeOp : public XlaOpKernel { int64 product = 1; int unknown_index = -1; for (int d = 0; d < num_dims; ++d) { - const int32 size = literal.Get({d}); + const int32 size = shape_input[d]; if (size == -1) { OP_REQUIRES( ctx, unknown_index == -1, diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 64900e4709fd3e16d21096b0cfff8922906cb0d4..e172c649325adb6f7761ce0be141f21e8d545bc1 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -48,6 +48,15 @@ class RetvalOp : public XlaOpKernel { } else { xla::XlaOp input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); + DataType input_type = ctx->input_type(0); + XlaContext& tc = XlaContext::Get(ctx); + + if (input_type == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + ctx->SetStatus(tc.AddResourceRetval(index_, resource)); + return; + } auto is_constant = ctx->builder()->IsConstant(input); if (!is_constant.ok()) { @@ -55,7 +64,6 @@ class RetvalOp : public XlaOpKernel { return; } - XlaContext& tc = XlaContext::Get(ctx); if (tc.resolve_compile_time_constants() && (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { xla::Literal literal; @@ -104,7 +112,8 @@ class RetvalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp); +REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(), + RetvalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index d962ef4a5f53470838643541f8a1e693d2f4011c..c0afccaa5b15dd33fcd016dfdd9bb18e244bf90a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -95,10 +95,24 @@ class ReverseV2Op : public XlaOpKernel { std::vector axes; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes)); + // witnessed_axes is used to ensure that the same axis is not marked to be + // reversed multiple times. + gtl::InlinedVector witnessed_axes(x_shape.dims(), false); + for (int d = 0; d < axes.size(); ++d) { - OP_REQUIRES(ctx, (0 <= axes[d]) && (axes[d] < x_shape.dims()), - errors::InvalidArgument(axes[d], " is out of range [0, ", - x_shape.dims(), ").")); + OP_REQUIRES( + ctx, (-x_shape.dims() <= axes[d]) && (axes[d] < x_shape.dims()), + errors::InvalidArgument(axes[d], " is out of range [-", + x_shape.dims(), ", ", x_shape.dims(), ").")); + // Axes can be negative and are shifted to the canonical index before + // being lowered to HLO. + if (axes[d] < 0) { + axes[d] += x_shape.dims(); + } + OP_REQUIRES(ctx, !witnessed_axes[axes[d]], + errors::InvalidArgument("canonicalized axis ", axes[d], + " was repeated.")); + witnessed_axes[axes[d]] = true; } ctx->SetOutput(0, xla::Rev(ctx->Input(0), axes)); diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 6ce50efb4aa6e3434a7c6009cf9f52f6cff9cc9f..9e4c57c9bf73369662274f6b783418e18ff860c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -66,8 +66,8 @@ class SelectOp : public XlaOpKernel { // XLA. It seems we have to broadcast on the left and then Reshape // to get the dimensions in the right order. 
const auto dim_sizes = then_shape.dim_sizes(); - gtl::ArraySlice bdims = dim_sizes; - bdims.pop_front(); + absl::Span bdims = dim_sizes; + bdims.remove_prefix(1); cond_handle = xla::Broadcast(cond_handle, bdims); std::vector dim_order(then_shape.dims()); diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 6adc3c58de63ee70c26bed47eebef955893df4a5..537b71f3c0cf3622a8a45a717ac406da69f5c3c7 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Slice Op. +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mem.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 025ba827410f1a9f993a8a1855558a2daa86609b..d6bd927135c013ac1ec3f6547aef358dc2741896 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Ops for softmax. +#include "absl/strings/match.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -25,7 +26,6 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace { @@ -33,7 +33,7 @@ namespace { class SoftmaxOp : public XlaOpKernel { public: explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - log_ = str_util::StartsWith(type_string(), "Log"); + log_ = absl::StartsWith(type_string(), "Log"); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 7327258c31f21f45ff7ffffbc9db7a2a70b4a14c..b7b4f3a5465c8eea832ef940b7c84a7435edc38c 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -23,7 +23,7 @@ namespace { void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, DataType input_dtype, const TensorShape& input_tensor_shape, - gtl::ArraySlice block_shape, + absl::Span block_shape, const xla::Literal& paddings) { const int input_rank = input_tensor_shape.dims(); const gtl::InlinedVector input_shape = @@ -34,7 +34,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, ctx, input_rank >= 1 + block_rank, errors::InvalidArgument("input rank should be >= ", 1 + block_rank, " instead of ", input_rank)); - gtl::ArraySlice remainder_shape(input_shape); + absl::Span remainder_shape(input_shape); remainder_shape.remove_prefix(1 + block_rank); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 1062399d91bd9a9bf8c3820c5ecac534c110746d..472d4744d7d9cec65645c3259b0c097f0c756bac 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/util/strided_slice_op.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mem.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index be1814d8e3ae2c0ddad0134b9288e0ea084aa81b..bb114d1aedd57c7de992a05b37ad53443489596f 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -122,7 +122,7 @@ Status GetTensorArrayShape(const XlaResource* resource, // relevant slice of 'operand'. xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, const xla::XlaOp& update, - const gtl::ArraySlice& update_dims, + absl::Span update_dims, const xla::XlaOp& start_indices) { xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims); xla::XlaOp sum = xla::Add(current, update); diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 1233a37565d3a40c6dd2882b3139dedbf690a7b6..93d5996b5eaf10221b1d7067e7650b78cd6b8fef 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -16,6 +16,7 @@ limitations under the License. // XLA-specific Tile Op. 
#include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -70,7 +70,7 @@ class TileOp : public XlaOpKernel { bool one_dimension_is_broadcasted_without_multiple = true; for (int i = 0; i < input_dims; ++i) { int multiple = literal.Get({i}); - OP_REQUIRES(ctx, multiple, + OP_REQUIRES(ctx, multiple >= 0, errors::InvalidArgument("Expected multiples[", i, "] >= 0, but got ", multiple)); int64 new_dim = input_shape.dim_size(i) * multiple; diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index be5e91138656716daddcc3c7a68dbb78ecb69103..7077c2e3a546e198bdb4ff944ea531f3158810f2 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -688,7 +688,7 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, } // grad_to_use = grad + 2 * l2_shrinkage * var - // new_accum = accum + grad_to_use * grad_to_use + // new_accum = accum + grad * grad // linear += grad_to_use - // (new_accum^(-lr_power) - accum^(-lr_power)) / lr * var // quadratic = (new_accum^(-lr_power) / lr) + 2 * l2 @@ -704,7 +704,7 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, grad_to_use = grad; } - xla::XlaOp new_accum = accum + xla::Square(grad_to_use); + xla::XlaOp new_accum = accum + xla::Square(grad); xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power); xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power); linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var; diff --git 
a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..412afeaaad96842521fbd306f5b666e837e675fd
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc
@@ -0,0 +1,122 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace {
+
+// Lowers the XlaBroadcastHelper op: when broadcast_dims is non-empty, the
+// lower-rank operand is expanded with xla::BroadcastInDim to the rank of the
+// other operand; otherwise both operands are passed through untouched.
+class XlaBroadcastHelperOp : public XlaOpKernel {
+ public:
+  explicit XlaBroadcastHelperOp(OpKernelConstruction* context)
+      : XlaOpKernel(context) {}
+
+  void Compile(XlaOpKernelContext* context) override {
+    xla::XlaOp lhs = context->Input(0);
+    xla::XlaOp rhs = context->Input(1);
+    const TensorShape lhs_shape = context->InputShape(0);
+    const TensorShape rhs_shape = context->InputShape(1);
+
+    const bool broadcast_lhs = lhs_shape.dims() < rhs_shape.dims();
+    const TensorShape* min_rank_shape = broadcast_lhs ? &lhs_shape : &rhs_shape;
+    const TensorShape* max_rank_shape = broadcast_lhs ? &rhs_shape : &lhs_shape;
+
+    std::vector<int64> broadcast_dims;
+    OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("broadcast_dims",
+                                                              &broadcast_dims));
+    // No explicit mapping: only validate the ranks and forward both operands.
+    if (broadcast_dims.empty()) {
+      OP_REQUIRES(
+          context,
+          lhs_shape.dims() == rhs_shape.dims() || lhs_shape.dims() == 0 ||
+              rhs_shape.dims() == 0,
+          errors::InvalidArgument(
+              "If broadcast_dims is empty, both "
+              "arguments must have equal rank, "
+              "or at least one argument must be a scalar; argument shapes: ",
+              lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+      context->SetOutput(0, lhs);
+      context->SetOutput(1, rhs);
+      return;
+    }
+
+    OP_REQUIRES(
+        context, broadcast_dims.size() == min_rank_shape->dims(),
+        errors::InvalidArgument(
+            "broadcast_dims must have size equal to the smaller argument rank; "
+            "broadcast_dims: [",
+            absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ",
+            lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+    // broadcast_dims must be strictly increasing (sorted, no duplicates).
+    std::vector<int64> sorted_broadcast_dims = broadcast_dims;
+    absl::c_sort(sorted_broadcast_dims);
+    std::set<int64> dims_set(broadcast_dims.begin(), broadcast_dims.end());
+    OP_REQUIRES(context,
+                dims_set.size() == broadcast_dims.size() &&
+                    broadcast_dims == sorted_broadcast_dims,
+                errors::InvalidArgument(
+                    "Duplicate or nonmonotonic dimension in broadcast_dims; "
+                    "broadcast_dims: [",
+                    absl::StrJoin(broadcast_dims, ","), "]"));
+
+    // Shape of max rank that is 1 everywhere except at broadcast_dims, where it
+    // takes the sizes of the lower-rank operand.
+    std::vector<int64> broadcast_shape(max_rank_shape->dims(), 1LL);
+    for (int i = 0; i < broadcast_dims.size(); ++i) {
+      // Keep the full 64-bit value so the range check below cannot be defeated
+      // by truncation.
+      const int64 dim = broadcast_dims[i];
+      OP_REQUIRES(
+          context, dim >= 0 && dim < broadcast_shape.size(),
+          errors::InvalidArgument(
+              "Invalid broadcast dimension (", dim, "); broadcast_dims: [",
+              absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ",
+              lhs_shape.DebugString(), " and ", rhs_shape.DebugString()));
+      broadcast_shape[dim] = min_rank_shape->dim_size(i);
+    }
+    xla::PrimitiveType type = context->input_xla_type(0);
+    xla::Shape broadcast_xla_shape =
+        xla::ShapeUtil::MakeShape(type, broadcast_shape);
+    if (broadcast_lhs) {
+      lhs = xla::BroadcastInDim(lhs, broadcast_xla_shape, broadcast_dims);
+    } else {
+      rhs = xla::BroadcastInDim(rhs, broadcast_xla_shape, broadcast_dims);
+    }
+    context->SetOutput(0, lhs);
+    context->SetOutput(1, rhs);
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaBroadcastHelperOp);
+};
+
+REGISTER_XLA_OP(
+    Name("XlaBroadcastHelper").CompileTimeConstInput("broadcast_dims"),
+    XlaBroadcastHelperOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8848623868091f8d19b1622f23ba23c68689d90d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+// Lowers the XlaConv op to xla::ConvGeneralDilated. The dimension numbers and
+// precision config are parsed once, at construction, from string attributes.
+class XlaConvOp : public XlaOpKernel {
+ public:
+  explicit XlaConvOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+    string dnums_attr;
+    OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+    OP_REQUIRES(
+        context, dnums_.ParsePartialFromString(dnums_attr),
+        errors::InvalidArgument("Error parsing convolution dimension numbers"));
+    string precision_config_attr;
+    OP_REQUIRES_OK(
+        context, context->GetAttr("precision_config", &precision_config_attr));
+    OP_REQUIRES(
+        context,
+        precision_config_.ParsePartialFromString(precision_config_attr),
+        errors::InvalidArgument("Error parsing convolution precision config"));
+  }
+
+  void Compile(XlaOpKernelContext* context) override {
+    const TensorShape padding_shape = context->InputShape("padding");
+    std::vector<int64> window_strides;
+    std::vector<int64> lhs_dilation;
+    std::vector<int64> rhs_dilation;
+    int64 feature_group_count;
+    OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+                                                              &window_strides));
+    OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("lhs_dilation",
+                                                              &lhs_dilation));
+    OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("rhs_dilation",
+                                                              &rhs_dilation));
+    OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(
+                                "feature_group_count", &feature_group_count));
+
+    OP_REQUIRES(context,
+                TensorShapeUtils::IsMatrix(padding_shape) &&
+                    padding_shape.dim_size(1) == 2,
+                errors::InvalidArgument(
+                    "padding must be a matrix with minor dimension 2, got ",
+                    padding_shape.DebugString()));
+    xla::Literal padding_literal;
+    OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+                                "padding", &padding_literal));
+    std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0));
+    for (int i = 0; i < padding.size(); ++i) {
+      padding[i] = {padding_literal.Get<int64>({i, 0}),
+                    padding_literal.Get<int64>({i, 1})};
+    }
+
+    // We do only minimal checking, relying on XLA to check the shape
+    // invariants.
+    xla::XlaOp output = xla::ConvGeneralDilated(
+        context->Input(0), context->Input(1), window_strides, padding,
+        lhs_dilation, rhs_dilation, dnums_, feature_group_count,
+        &precision_config_);
+    context->SetOutput(0, output);
+  }
+
+ private:
+  xla::ConvolutionDimensionNumbers dnums_;
+  xla::PrecisionConfigProto precision_config_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp);
+};
+
+REGISTER_XLA_OP(Name("XlaConv")
+                    .CompileTimeConstInput("window_strides")
+                    .CompileTimeConstInput("lhs_dilation")
+                    .CompileTimeConstInput("rhs_dilation")
+                    .CompileTimeConstInput("feature_group_count")
+                    .CompileTimeConstInput("padding"),
+                XlaConvOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2fed53e5c072e1a50e0f07f45357ee86c90f986f
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+// Lowers the XlaDot op to xla::DotGeneral. The dimension numbers and precision
+// config are parsed once, at construction, from string attributes.
+class XlaDotOp : public XlaOpKernel {
+ public:
+  explicit XlaDotOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+    string dnums_attr;
+    OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr));
+    OP_REQUIRES(
+        context, dnums_.ParsePartialFromString(dnums_attr),
+        errors::InvalidArgument("Error parsing dot dimension numbers"));
+    string precision_config_attr;
+    OP_REQUIRES_OK(
+        context, context->GetAttr("precision_config", &precision_config_attr));
+    OP_REQUIRES(
+        context,
+        precision_config_.ParsePartialFromString(precision_config_attr),
+        errors::InvalidArgument("Error parsing dot precision config"));
+  }
+
+  // Emits one xla::DotGeneral over inputs 0 and 1.
+  void Compile(XlaOpKernelContext* context) override {
+    // We do only minimal checking, relying on XLA to check the shape
+    // invariants.
+    xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1),
+                                        dnums_, &precision_config_);
+    context->SetOutput(0, output);
+  }
+
+ private:
+  xla::DotDimensionNumbers dnums_;
+  xla::PrecisionConfigProto precision_config_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp);
+};
+
+REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..59502d83c7338bd1b05b3323a97761fff2da186a
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc
@@ -0,0 +1,105 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+// Lowers the XlaPad op to xla::Pad after validating the padding inputs.
+class XlaPadOp : public XlaOpKernel {
+ public:
+  explicit XlaPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+  void Compile(XlaOpKernelContext* context) override {
+    const TensorShape input_shape = context->InputShape("input");
+    const TensorShape padding_value_shape =
+        context->InputShape("padding_value");
+
+    std::vector<int64> padding_low;
+    std::vector<int64> padding_high;
+    std::vector<int64> padding_interior;
+    OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_low",
+                                                              &padding_low));
+    OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_high",
+                                                              &padding_high));
+    OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+                                "padding_interior", &padding_interior));
+
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(padding_value_shape),
+                errors::InvalidArgument("padding_value must be a scalar"));
+    const int rank = input_shape.dims();
+    OP_REQUIRES(context, rank == padding_low.size(),
+                errors::InvalidArgument(
+                    "The size of padding_low must be equal to the input "
+                    "rank (",
+                    padding_low.size(), " vs. ", rank, ")"));
+    OP_REQUIRES(context, rank == padding_high.size(),
+                errors::InvalidArgument(
+                    "The size of padding_high must be equal to the input "
+                    "rank (",
+                    padding_high.size(), " vs. ", rank, ")"));
+    OP_REQUIRES(context, rank == padding_interior.size(),
+                errors::InvalidArgument(
+                    "The size of padding_interior must be equal to the input "
+                    "rank (",
+                    padding_interior.size(), " vs. ", rank, ")"));
+
+    auto non_negative = [](int64 x) { return x >= 0; };
+    OP_REQUIRES(
+        context, absl::c_all_of(padding_low, non_negative),
+        errors::InvalidArgument("padding_low must be non-negative, got [",
+                                absl::StrJoin(padding_low, ","), "]"));
+    OP_REQUIRES(
+        context, absl::c_all_of(padding_high, non_negative),
+        errors::InvalidArgument("padding_high must be non-negative, got [",
+                                absl::StrJoin(padding_high, ","), "]"));
+    OP_REQUIRES(
+        context, absl::c_all_of(padding_interior, non_negative),
+        errors::InvalidArgument("padding_interior must be non-negative, got [",
+                                absl::StrJoin(padding_interior, ","), "]"));
+
+    xla::PaddingConfig padding_config;
+    for (int i = 0; i < rank; ++i) {
+      auto* dim = padding_config.add_dimensions();
+      dim->set_edge_padding_low(padding_low[i]);
+      dim->set_edge_padding_high(padding_high[i]);
+      dim->set_interior_padding(padding_interior[i]);
+    }
+
+    xla::XlaOp output =
+        xla::Pad(context->Input("input"), context->Input("padding_value"),
+                 padding_config);
+    context->SetOutput(0, output);
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaPadOp);
+};
+
+REGISTER_XLA_OP(Name("XlaPad")
+                    .CompileTimeConstInput("padding_low")
+                    .CompileTimeConstInput("padding_high")
+                    .CompileTimeConstInput("padding_interior"),
+                XlaPadOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fc2425f37bfa793ce3a106b635c9dffd15b975ff
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+// Lowers the XlaReduce op to xla::Reduce using the function in `reducer`.
+class XlaReduceOp : public XlaOpKernel {
+ public:
+  explicit XlaReduceOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("reducer", &reducer_));
+    OP_REQUIRES_OK(context, context->GetAttr("dimensions_to_reduce",
+                                             &dimensions_to_reduce_));
+    std::set<int64> dims_set(dimensions_to_reduce_.begin(),
+                             dimensions_to_reduce_.end());
+    OP_REQUIRES(
+        context, dims_set.size() == dimensions_to_reduce_.size(),
+        errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce "
+                                "argument to XlaReduce"));
+  }
+
+  void Compile(XlaOpKernelContext* context) override {
+    const TensorShape input_shape = context->InputShape("input");
+    const TensorShape init_value_shape = context->InputShape("init_value");
+    const DataType dtype = context->input_type(0);
+
+    const int rank = input_shape.dims();
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape),
+                errors::InvalidArgument("init_value must be a scalar"));
+
+    auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; };
+    OP_REQUIRES(context,
+                rank >= dimensions_to_reduce_.size() &&
+                    absl::c_all_of(dimensions_to_reduce_, dim_in_range),
+                errors::InvalidArgument(
+                    "Invalid dimensions_to_reduce argument to XlaReduce"));
+
+    // Build the reducer function.
+    XlaCompiler::Argument reducer_arg;
+    reducer_arg.kind = XlaCompiler::Argument::kParameter;
+    reducer_arg.type = dtype;
+    reducer_arg.shape = TensorShape();
+
+    XlaCompiler::CompileOptions compile_options;
+    compile_options.use_tuple_arg = false;
+    compile_options.always_return_tuple = false;
+    compile_options.resolve_compile_time_constants = false;
+    compile_options.is_entry_computation = false;
+    XlaCompiler::CompilationResult reducer;
+    OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+                                compile_options, *reducer_,
+                                {reducer_arg, reducer_arg}, &reducer));
+
+    xla::Shape scalar_shape;
+    OP_REQUIRES_OK(context,
+                   TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+    OP_REQUIRES(
+        context,
+        xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape),
+        errors::InvalidArgument(
+            "Invalid output shape of XlaReduce reducer. Expected ",
+            xla::ShapeUtil::HumanString(scalar_shape), " got ",
+            xla::ShapeUtil::HumanString(reducer.xla_output_shape)));
+
+    xla::XlaOp output =
+        xla::Reduce(context->Input("input"), context->Input("init_value"),
+                    *reducer.computation, dimensions_to_reduce_);
+    context->SetOutput(0, output);
+  }
+
+ private:
+  const NameAttrList* reducer_;
+  std::vector<int64> dimensions_to_reduce_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp);
+};
+
+REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..089776fcf74fcf6b363dfff5de8d86d7449eacd6
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc
@@ -0,0 +1,147 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/kernels/while_op.h"
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+// Lowers XlaSelectAndScatter to xla::SelectAndScatterWithGeneralPadding.
+class XlaSelectAndScatterOp : public XlaOpKernel {
+ public:
+  explicit XlaSelectAndScatterOp(OpKernelConstruction* context)
+      : XlaOpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("select", &select_computation_));
+    OP_REQUIRES_OK(context, context->GetAttr("scatter", &scatter_computation_));
+  }
+
+  void Compile(XlaOpKernelContext* context) override {
+    const TensorShape input_shape = context->InputShape(0);
+    const DataType dtype = context->input_type(0);
+
+    std::vector<int64> window_dimensions;
+    std::vector<int64> window_strides;
+    OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+                                "window_dimensions", &window_dimensions));
+    OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
+                                                              &window_strides));
+
+    const int rank = input_shape.dims();
+    OP_REQUIRES(context, rank == window_dimensions.size(),
+                errors::InvalidArgument(
+                    "The size of window_dimensions must be equal to the input "
+                    "rank (",
+                    window_dimensions.size(), " vs. ", rank, ")"));
+    OP_REQUIRES(context, rank == window_strides.size(),
+                errors::InvalidArgument(
+                    "The size of window_strides must be equal to the input "
+                    "rank (",
+                    window_strides.size(), " vs. ", rank, ")"));
+
+    XlaCompiler::CompileOptions compile_options;
+    compile_options.use_tuple_arg = false;
+    compile_options.resolve_compile_time_constants = false;
+    compile_options.is_entry_computation = false;
+    compile_options.always_return_tuple = false;
+
+    // Build the select function.
+    XlaCompiler::Argument select_arg;
+    select_arg.kind = XlaCompiler::Argument::kParameter;
+    select_arg.type = dtype;
+    select_arg.shape = TensorShape();
+
+    XlaCompiler::CompilationResult select;
+    OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+                                compile_options, *select_computation_,
+                                {select_arg, select_arg}, &select));
+
+    xla::Shape select_output_shape = xla::ShapeUtil::MakeShape(xla::PRED, {});
+    OP_REQUIRES(
+        context,
+        xla::ShapeUtil::Compatible(select.xla_output_shape,
+                                   select_output_shape),
+        errors::InvalidArgument(
+            "Invalid output shape of XlaSelectAndScatter select. Expected ",
+            xla::ShapeUtil::HumanString(select_output_shape), " got ",
+            xla::ShapeUtil::HumanString(select.xla_output_shape)));
+
+    // Build the scatter function.
+    XlaCompiler::Argument scatter_arg;
+    scatter_arg.kind = XlaCompiler::Argument::kParameter;
+    scatter_arg.type = dtype;
+    scatter_arg.shape = TensorShape();
+
+    XlaCompiler::CompilationResult scatter;
+    OP_REQUIRES_OK(context, context->compiler()->CompileFunction(
+                                compile_options, *scatter_computation_,
+                                {scatter_arg, scatter_arg}, &scatter));
+
+    xla::Shape scalar_shape;
+    OP_REQUIRES_OK(context,
+                   TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape));
+    OP_REQUIRES(
+        context,
+        xla::ShapeUtil::Compatible(scatter.xla_output_shape, scalar_shape),
+        errors::InvalidArgument(
+            "Invalid output shape of XlaSelectAndScatter scatter. Expected ",
+            xla::ShapeUtil::HumanString(scalar_shape), " got ",
+            xla::ShapeUtil::HumanString(scatter.xla_output_shape)));
+
+    const TensorShape padding_shape = context->InputShape("padding");
+    OP_REQUIRES(context,
+                TensorShapeUtils::IsMatrix(padding_shape) &&
+                    padding_shape.dim_size(1) == 2,
+                errors::InvalidArgument(
+                    "padding must be a matrix with minor dimension 2, got ",
+                    padding_shape.DebugString()));
+    xla::Literal padding_literal;
+    OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal(
+                                "padding", &padding_literal));
+    std::vector<std::pair<int64, int64>> padding(padding_shape.dim_size(0));
+    for (int i = 0; i < padding.size(); ++i) {
+      padding[i] = {padding_literal.Get<int64>({i, 0}),
+                    padding_literal.Get<int64>({i, 1})};
+    }
+
+    xla::XlaOp output = xla::SelectAndScatterWithGeneralPadding(
+        context->Input("operand"), *select.computation, window_dimensions,
+        window_strides, padding, context->Input("source"),
+        context->Input("init_value"), *scatter.computation);
+    context->SetOutput(0, output);
+  }
+
+ private:
+  const NameAttrList* select_computation_;
+  const NameAttrList* scatter_computation_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaSelectAndScatterOp);
+};
+
+REGISTER_XLA_OP(Name("XlaSelectAndScatter")
+                    .CompileTimeConstInput("window_dimensions")
+                    .CompileTimeConstInput("window_strides")
+                    .CompileTimeConstInput("padding"),
+                XlaSelectAndScatterOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index cb7a40e23d539f758d963791f1c2b4d37374ade5..9365d203f06d9f1cad320353f43db010d39697af 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -25,8 +25,8 @@ cc_library(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/core:lib", ], ) @@ -44,8 +44,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/core:lib", ], @@ -78,8 +78,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", @@ -104,6 +104,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -119,6 +120,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", @@ -165,6 +167,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -203,5 +206,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index f666d22ea44216beef74608bb4d9f33fb2fe82c6..d8c050d09e871c80e128989c9fbdb57c266b19ed 100644 
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -27,7 +27,8 @@ limitations under the License. namespace tensorflow { xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, bool conjugate_y) { + bool transpose_y, bool conjugate_x, bool conjugate_y, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); @@ -95,6 +96,10 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, y = xla::Conj(y); } + xla::PrecisionConfigProto precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + // If there are no batch dimensions, use a regular Dot. // TODO(b/69062148) Remove this code when Dot emitters can be passed // dimensions to transpose directly (i.e. without requiring a Transpose @@ -102,7 +107,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, if (batch_dimension_numbers.empty()) { auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; auto rhs = transpose_y ? 
xla::Transpose(y, {1, 0}) : y; - return xla::Dot(lhs, rhs); + return xla::Dot(lhs, rhs, &precision_proto); } xla::DotDimensionNumbers dot_dnums; @@ -112,7 +117,8 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - return xla::DotGeneral(x, y, dot_dnums); + + return xla::DotGeneral(x, y, dot_dnums, &precision_proto); }); } diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 8757b16a1ca6a8cec5e3c801c885e7bbbb2f2c76..6cfccd55530ff40a309673d57d1fe61fc8264316 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -45,7 +45,9 @@ namespace tensorflow { // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false); + bool conjugate_y = false, + xla::PrecisionConfigProto::Precision precision = + xla::PrecisionConfigProto::DEFAULT); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 87d73eb3f07ebd7dfa4fef50ebe76cad0c4ed117..c50a8de33e93a91b1a414146147de48df603eb85 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -49,20 +49,22 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { +xla::XlaOp CholeskyUnblocked(xla::XlaOp a, + 
xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int n_dims = xla::ShapeUtil::Rank(a_shape); const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); - gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - 2); + auto major_dims = xla::AsInt64Slice(a_shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - 2); xla::XlaOp l = xla::ZerosLike(a); // Construct the for loop body to iterate over rows. - auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, xla::XlaBuilder* body_builder) -> xla::StatusOr> { xla::Shape col_shape; @@ -101,7 +103,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { // np.dot(row, np.swapaxes(row, -1, -2)) auto diag_dot = BatchDot(row, row, /*transpose_x=*/false, - /*transpose_y=*/true); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) auto l_ii = @@ -121,7 +124,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { // r.T) auto dot = BatchDot(body_l, row, /*transpose_x=*/false, - /*transpose_y=*/true); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); // np.dot(l[..., i+1:, :i], r.T) auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); @@ -145,7 +149,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { } // namespace -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) { +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -181,14 +186,15 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 
block_size) { auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); a = UpdateSliceInMinorDims(a, before - delta, {i, i}); } // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k}); - auto factorized = CholeskyUnblocked(x); + auto factorized = CholeskyUnblocked(x, precision); l = UpdateSliceInMinorDims(l, factorized, {i, i}); if (i + k < n) { diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 1bef9bb166c576ec665bb48265b4da200ddca2a0..60cd7ded53fe862f29ca2bb68b175fcd1c89b70c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -30,7 +30,9 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. 
// TODO(znado): handle the complex Hermitian case -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256, + xla::PrecisionConfigProto::Precision precision = + xla::PrecisionConfigProto::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index fc0c1ee838190b1f1a7ca5b901c97e0a35232a97..0a140fa93caec28ebbbd666fd4fa518222ea23a4 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -65,9 +65,9 @@ namespace { // return (v, tau, beta) // TODO(phawkins): LAPACK's xLARFG implementation has code for handling // overflows in the norm/beta calculations. Perhaps do the same here. -xla::Status House(xla::XlaOp x, xla::XlaOp k, gtl::ArraySlice batch_dims, - const int64 m, xla::XlaOp* v, xla::XlaOp* tau, - xla::XlaOp* beta) { +xla::Status House(xla::XlaOp x, xla::XlaOp k, + absl::Span batch_dims, const int64 m, + xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) { xla::XlaBuilder* const builder = x.builder(); TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); const xla::PrimitiveType type = x_shape.element_type(); @@ -149,7 +149,8 @@ struct QRBlockResult { xla::XlaOp taus; // Shape: [..., n] xla::XlaOp vs; // Shape: [..., m, n] }; -xla::StatusOr QRBlock(xla::XlaOp a) { +xla::StatusOr QRBlock( + xla::XlaOp a, xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -172,7 +173,7 @@ xla::StatusOr QRBlock(xla::XlaOp a) { std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); auto qr_body_fn = - [&](xla::XlaOp j, gtl::ArraySlice values, + [&](xla::XlaOp j, absl::Span values, xla::XlaBuilder* builder) -> xla::StatusOr> { auto a = values[0]; auto vs = values[1]; @@ -190,8 +191,12 @@ xla::StatusOr QRBlock(xla::XlaOp a) { auto v_broadcast = 
xla::Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) - auto vva = BatchDot(v_broadcast, a); - vva = BatchDot(v_broadcast, vva, /*transpose_x=*/true); + auto vva = + BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + vva = + BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); a = a - xla::Mul(tau, vva, /*broadcast_dimensions=*/batch_dim_indices); @@ -250,14 +255,15 @@ xla::StatusOr QRBlock(xla::XlaOp a) { // There is no need to return Y since at termination of the loop it is equal to // vs. xla::StatusOr ComputeWYRepresentation( - xla::PrimitiveType type, gtl::ArraySlice batch_dims, xla::XlaOp vs, - xla::XlaOp taus, int64 m, int64 n) { + xla::PrimitiveType type, absl::Span batch_dims, xla::XlaOp vs, + xla::XlaOp taus, int64 m, int64 n, + xla::PrecisionConfigProto::Precision precision) { std::vector batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; auto body_fn = - [&](xla::XlaOp j, gtl::ArraySlice values, + [&](xla::XlaOp j, absl::Span values, xla::XlaBuilder* builder) -> xla::StatusOr> { auto w = values[0]; auto y = values[1]; @@ -272,9 +278,12 @@ xla::StatusOr ComputeWYRepresentation( auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); // yv has shape [..., n, 1] - auto yv = BatchDot(y, v, /*transpose_x=*/true); + auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); // wyv has shape [..., m, 1] - auto wyv = BatchDot(w, yv); + auto wyv = + BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); auto z = xla::Mul( -beta, v + wyv, @@ -321,8 +330,9 @@ xla::StatusOr ComputeWYRepresentation( // return (q, a) // 
TODO(phawkins): consider using UT transformations (in the form I - V U V') // rather than WY transformations. -xla::StatusOr QRDecomposition(xla::XlaOp a, - int64 block_size) { +xla::StatusOr QRDecomposition( + xla::XlaOp a, bool full_matrices, int64 block_size, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -352,33 +362,47 @@ xla::StatusOr QRDecomposition(xla::XlaOp a, int64 k = std::min(block_size, p - i); auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k}); - TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block)); + TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block, precision)); a = UpdateSliceInMinorDims(a, qr_block.r, {i, i}); // Compute the I-WY block representation of a product of Householder // matrices. - TF_ASSIGN_OR_RETURN(auto w, - ComputeWYRepresentation(type, batch_dims, qr_block.vs, - qr_block.taus, m - i, k)); + TF_ASSIGN_OR_RETURN( + auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs, + qr_block.taus, m - i, k, precision)); auto y = qr_block.vs; // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); - auto a_update = BatchDot(w, a_panel, /*transpose_x=*/true); - a_update = BatchDot(y, a_update); + auto a_update = + BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + a_update = + BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); a_panel = a_panel + a_update; a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); - auto q_update = BatchDot(q_panel, w); - q_update = - BatchDot(q_update, y, /*transpose_x=*/false, /*transpose_y=*/true); + auto q_update = + BatchDot(q_panel, w, 
/*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + q_update = BatchDot(q_update, y, /*transpose_x=*/false, + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); q_panel = q_panel + q_update; q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } QRDecompositionResult result; + + // full_matrices is false when only a partial result in needed. Slice to the + // needed dimensions here. + if (!full_matrices) { + q = SliceInMinorDims(q, {0, 0}, {m, p}); + a = SliceInMinorDims(a, {0, 0}, {p, n}); + } result.q = q; result.r = a; return result; diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h index abd2316ac961f583dd29f90f43cf6209de30bd6a..8a389fb7b053257adcd2a338dca52445c78381d1 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -32,8 +33,10 @@ struct QRDecompositionResult { xla::XlaOp r; }; -xla::StatusOr QRDecomposition(xla::XlaOp a, - int64 block_size = 128); +xla::StatusOr QRDecomposition( + xla::XlaOp a, bool full_matrices, int64 block_size = 128, + xla::PrecisionConfigProto::Precision precision = + xla::PrecisionConfigProto::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index ba22eff73abab11abeb57283c63318b2e50a9ca1..38dfde165df47ca78a25a068a901cd1071aa55e2 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -40,9 +40,9 @@ xla::StatusOr XlaScatter( TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); TF_RETURN_IF_ERROR(builder->GetShape(updates).status()); TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices)); - gtl::ArraySlice indices_dims = + absl::Span indices_dims = xla::AsInt64Slice(indices_shape.dimensions()); - gtl::ArraySlice buffer_dims = + absl::Span buffer_dims = xla::AsInt64Slice(buffer_shape.dimensions()); // If the indices are N-dimensional, the minor dimension of indices contains @@ -58,7 +58,7 @@ xla::StatusOr XlaScatter( ") must be <= the rank of the buffer (shape: ", xla::ShapeUtil::HumanString(buffer_shape), ")"); } - indices_dims.pop_back(); + indices_dims.remove_suffix(1); } int64 num_indices = 1; @@ -107,7 +107,7 @@ xla::StatusOr XlaScatter( // index = dynamic-slice(indices, i) // update = dynamic-slice(updates, i) // buffer = dynamic-update-slice(buffer, update, index) - auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + auto body_fn = [&](xla::XlaOp i, absl::Span loop_vars, xla::XlaBuilder* body_builder) { auto indices = loop_vars[0]; auto updates = loop_vars[1]; diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 04fa10108cef66f429392951eea70e59643a2d29..37b2240b45b4ae6a587c827cfdfa1096b4e1737e 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -57,7 +57,7 @@ xla::XlaOp 
DiagonalBlocks(xla::XlaOp a, int64 block_size) { // We can grab entire blocks using gather if (n > block_size) { // Construct the starting indices of the diagonal blocks - auto gather_indices = + auto start_indices = Transpose(Broadcast(Mul(Iota(builder, xla::S32, num_blocks), xla::ConstantR0(builder, block_size)), /*broadcast_sizes=*/{2}), @@ -65,13 +65,13 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { // Gather the diagonal blocks xla::GatherDimensionNumbers dim_numbers; - dim_numbers.add_output_window_dims(ndims - 1); - dim_numbers.add_output_window_dims(ndims); - dim_numbers.add_gather_dims_to_operand_dims(ndims - 2); - dim_numbers.add_gather_dims_to_operand_dims(ndims - 1); + dim_numbers.add_offset_dims(ndims - 1); + dim_numbers.add_offset_dims(ndims); + dim_numbers.add_start_index_map(ndims - 2); + dim_numbers.add_start_index_map(ndims - 1); dim_numbers.set_index_vector_dim(1); - diag_blocks = Gather(a, gather_indices, dim_numbers, - /*window_bounds=*/{block_size, block_size}); + diag_blocks = Gather(a, start_indices, dim_numbers, + /*slice_sizes=*/{block_size, block_size}); } // The last block might be smaller than the block size, @@ -110,8 +110,9 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { }); } -xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, - bool transpose_a, bool conjugate_a) { +xla::XlaOp InvertDiagonalBlocks( + xla::XlaOp diag_blocks, bool lower, bool transpose_a, bool conjugate_a, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = diag_blocks.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { // Input is a batch of square lower triangular square matrices. 
Its shape is @@ -215,7 +216,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, dnums.add_rhs_batch_dimensions(0); dnums.add_lhs_contracting_dimensions(2); dnums.add_rhs_contracting_dimensions(1); - auto update = -DotGeneral(input_row, body_out, dnums); + xla::PrecisionConfigProto precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); body_out = DynamicUpdateSlice(body_out, update, start_indices); @@ -238,10 +242,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, }); } -xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, - xla::XlaOp inv_diag_blocks, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a) { +xla::XlaOp SolveWithInvertedDiagonalBlocks( + xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, @@ -307,9 +311,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, auto a_row = MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); if (left_side) { - remainder = b_row - BatchDot(a_row, x, transpose_a, false); + remainder = b_row - BatchDot(a_row, x, transpose_a, false, + /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); } else { - remainder = b_row - BatchDot(x, a_row, false, transpose_a); + remainder = b_row - BatchDot(x, a_row, false, transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); } } @@ -319,9 +327,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, xla::ConstantR0WithType(builder, xla::S32, j * block_size); std::vector update_starts = {start_index, zero}; if (left_side) { - x_update = 
BatchDot(inv_block, remainder, transpose_a, false); + x_update = + BatchDot(inv_block, remainder, transpose_a, false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); } else { - x_update = BatchDot(remainder, inv_block, false, transpose_a); + x_update = + BatchDot(remainder, inv_block, false, transpose_a, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); std::swap(update_starts[0], update_starts[1]); } x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); @@ -333,7 +345,8 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, - int64 block_size) { + int64 block_size, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -388,12 +401,13 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, auto diag_blocks = DiagonalBlocks(a, block_size); // We invert these blocks in parallel using batched matrix-vector products - auto inv_diag_blocks = - InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a); + auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, + conjugate_a, precision); // We now find the solution using GEMMs - auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, - lower, transpose_a, conjugate_a); + auto x = + SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower, + transpose_a, conjugate_a, precision); return x; }); diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 555760b7efabddfb25c9135b109a1c48b487415e..ac42a4835295b7cb52697710d738f4728d3983d1 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ 
b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -59,7 +59,9 @@ namespace tensorflow { // blocking is used. xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, - int64 block_size = 128); + int64 block_size = 128, + xla::PrecisionConfigProto::Precision precision = + xla::PrecisionConfigProto::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 8b5beba383cda45d36e2ee27ca5e3b3c5988b6b7..c26784852472061ffead03cfe7431f8b8ba0e555 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -113,8 +113,8 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, return xla::ConstantLiteral(builder, literal); } -xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, - gtl::ArraySlice end) { +xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, + absl::Span end) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_RET_CHECK(start.size() == end.size()); @@ -124,9 +124,10 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, const int64 n_dims = xla::ShapeUtil::Rank(shape); TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - n_minor_dims); + auto major_dims = xla::AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); // Prepends 0s in the major dim std::vector padded_start(n_dims, 0); @@ -143,8 +144,8 @@ xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, 
}); } -std::vector ConcatVectors(gtl::ArraySlice xs, - gtl::ArraySlice ys) { +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { std::vector output(xs.size() + ys.size()); std::copy(xs.begin(), xs.end(), output.begin()); std::copy(ys.begin(), ys.end(), output.begin() + xs.size()); @@ -152,8 +153,8 @@ std::vector ConcatVectors(gtl::ArraySlice xs, } xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - gtl::ArraySlice starts, - gtl::ArraySlice sizes) { + absl::Span starts, + absl::Span sizes) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); @@ -161,9 +162,10 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, int64 n_minor_dims = starts.size(); TF_RET_CHECK(n_minor_dims == sizes.size()); TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), - /*pos=*/0, - /*len=*/n_dims - sizes.size()); + auto major_dims = xla::AsInt64Slice(shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims - sizes.size()); auto padded_starts = PrependZerosInMajorDims(x, starts); auto padded_sizes = ConcatVectors(major_dims, sizes); return xla::DynamicSlice(x, padded_starts, padded_sizes); @@ -171,7 +173,7 @@ xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, } xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start) { + absl::Span start) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 
@@ -189,7 +191,7 @@ xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, } xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start) { + absl::Span start) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); @@ -204,13 +206,13 @@ xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, } xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice starts) { + absl::Span starts) { auto padded_starts = PrependZerosInMajorDims(x, starts); return xla::DynamicUpdateSlice(x, update, padded_starts); } xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - gtl::ArraySlice starts) { + absl::Span starts) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index b4905c952820a45371e090aa98466654e2db9661..80e9e5b002d49581209e608b98606e02709c5876 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -31,7 +31,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros // prepended until the array is length n_dims. 
xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - gtl::ArraySlice starts); + absl::Span starts); // Returns a integer scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. @@ -41,33 +41,33 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, // Builds a vector of zeros of length rank(x) with the last values being // those in `starts`. xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, - gtl::ArraySlice starts); + absl::Span starts); // Performs a slice in the minor dimensions of a Tensor. -xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice start, - gtl::ArraySlice end); +xla::XlaOp SliceInMinorDims(xla::XlaOp x, absl::Span start, + absl::Span end); // Returns the concatenation of `xs` and `ys`. -std::vector ConcatVectors(gtl::ArraySlice xs, - gtl::ArraySlice ys); +std::vector ConcatVectors(absl::Span xs, + absl::Span ys); // Performs a dynamic slice in the minor dimensions of a Tensor. xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x, - gtl::ArraySlice starts, - gtl::ArraySlice sizes); + absl::Span starts, + absl::Span sizes); // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start); + absl::Span start); // Updates a slice of 'x', where 'start' contains a list of minor dimensions: // x[..., start[0], ..., start[n]] = update xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice start); + absl::Span start); xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update, - gtl::ArraySlice starts); + absl::Span starts); // Transposes a stack of matrices `x` by swapping the last two dimensions. 
xla::XlaOp TransposeInMinorDims(xla::XlaOp x); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index d64394f1401d7ceea004a59c991ef6f4a1c58b41..5300e2c878bf725b65544701eb3fdc6032553491 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -24,7 +24,7 @@ namespace tensorflow { xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, StringPiece name, xla::XlaBuilder* builder) { int arity = initial_values.size(); std::vector var_shapes; @@ -84,15 +84,15 @@ xla::StatusOr> XlaWhileLoop( xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, StringPiece name, xla::XlaBuilder* builder) { auto while_cond_fn = - [&](gtl::ArraySlice values, + [&](absl::Span values, xla::XlaBuilder* cond_builder) -> xla::StatusOr { return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); }; - auto while_body_fn = [&](gtl::ArraySlice values, + auto while_body_fn = [&](absl::Span values, xla::XlaBuilder* body_builder) -> xla::StatusOr> { xla::XlaOp iteration = values[0]; diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 9493b1f109be0725f7f733b9f9da664264275a69..115ebf390df6c215680e5982a6ceba546f384af8 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -19,24 +19,24 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function(gtl::ArraySlice, +typedef std::function(absl::Span, xla::XlaBuilder*)> LoopConditionFunction; // Function that builds a loop body. Takes as input a sequence of input values // and returns a sequence of output values. typedef std::function>( - gtl::ArraySlice, xla::XlaBuilder*)> + absl::Span, xla::XlaBuilder*)> LoopBodyFunction; // Helper function for building an XLA while loop, where the values carried by @@ -50,7 +50,7 @@ typedef std::function>( xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, StringPiece name, xla::XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. @@ -59,13 +59,13 @@ xla::StatusOr> XlaWhileLoop( // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. 
typedef std::function>( - xla::XlaOp, gtl::ArraySlice, xla::XlaBuilder*)> + xla::XlaOp, absl::Span, xla::XlaBuilder*)> ForEachIndexBodyFunction; xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, StringPiece name, + absl::Span initial_values, StringPiece name, xla::XlaBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 77da1bf29ced60e490f07abad41cf8ce96232982..20103ec3ae00b57723e05326dbbb1b0f6e1a671a 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -49,9 +49,8 @@ Status HostTensorToMutableBorrowingLiteral( return Status::OK(); } -Status HostTensorsToBorrowingLiteralTuple( - tensorflow::gtl::ArraySlice host_tensors, - xla::BorrowingLiteral* literal) { +Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, + xla::BorrowingLiteral* literal) { std::vector buf_ptrs; buf_ptrs.reserve(host_tensors.size()); std::vector tensor_shapes(host_tensors.size()); diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 09d6fa811669b422532673540e4da47f47e6be4e..1db7470ee2a839099454b772d4833492e033bc92 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -18,11 +18,11 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -43,9 +43,8 @@ Status HostTensorToMutableBorrowingLiteral( // Returns a BorrowingLiteral tuple that utilizes the same underlying buffers // owned by 'host_tensors'. -Status HostTensorsToBorrowingLiteralTuple( - tensorflow::gtl::ArraySlice host_tensors, - xla::BorrowingLiteral* literal); +Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, + xla::BorrowingLiteral* literal); // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of // type . diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index a3404c2b3df7bf25011359d1f5f5b88c29a3f83b..7dc16b5a46791b81eef2c572736e1a1c7969b203 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -28,7 +28,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { { std::vector int64_values = {1, 2, 3}; std::unique_ptr int64_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int64_values)); + xla::LiteralUtil::CreateR1(absl::Span(int64_values)); Tensor host_tensor; EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) @@ -49,7 +49,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { Tensor host_tensor; std::vector int32_values = {10, 11}; std::unique_ptr int32_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int32_values)); + xla::LiteralUtil::CreateR1(absl::Span(int32_values)); EXPECT_TRUE( LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) .ok()); diff --git 
a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index ace6fd1d8eeaf439509a7b75d8d986997c392e73..4dce0a2102cf9c782850ccc7af4f14b59bd51e53 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -11,6 +11,8 @@ cc_library( srcs = ["xla_ops.cc"], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index a59c77f5c3a309abe8f6fbab1e48455d54e8fae5..2cd9ae799f06afdcbae5429ef8caffd3b4d29c29 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -13,11 +13,97 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/algorithm/container.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { +namespace { + +// Helper shape function for operators that return an output with the same rank +// as their first input. 
+Status UnchangedRank(shape_inference::InferenceContext* c) { + if (c->RankKnown(c->input(0))) { + c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0)))); + } else { + c->set_output(0, c->input(0)); + } + return Status::OK(); +} + +REGISTER_OP("XlaBroadcastHelper") + .Input("lhs: T") + .Input("rhs: T") + .Input("broadcast_dims: Tindices") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Output("lhs_output: T") + .Output("rhs_output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Helper operator for performing XLA-style broadcasts + +Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to +whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules +for binary operators. + +lhs: the LHS input tensor +rhs: the RHS input tensor +broadcast_dims: an XLA-style broadcast dimension specification +lhs_output: the broadcasted LHS tensor +rhs_output: the broadcasted RHS tensor +)doc"); + +REGISTER_OP("XlaConv") + .Input("lhs: T") + .Input("rhs: T") + .Input("window_strides: Tindices") + .Input("padding: Tindices") + .Input("lhs_dilation: Tindices") + .Input("rhs_dilation: Tindices") + .Input("feature_group_count: Tindices") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("dimension_numbers: string") + .Attr("precision_config: string") + .Output("output: T") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA ConvGeneralDilated operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution +. + +lhs: the input tensor +rhs: the kernel tensor +window_strides: the inter-window strides +padding: the padding to apply at the start and end of each input dimensions +lhs_dilation: dilation to apply between input elements +rhs_dilation: dilation to apply between kernel elements +feature_group_count: number of feature groups for grouped convolution. +dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto. 
+precision_config: a serialized xla::PrecisionConfigProto proto. +)doc"); + +REGISTER_OP("XlaDot") + .Input("lhs: T") + .Input("rhs: T") + .Attr("T: numbertype") + .Attr("dimension_numbers: string") + .Attr("precision_config: string") + .Output("output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA DotGeneral operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral +. + +lhs: the LHS tensor +rhs: the RHS tensor +dimension_numbers: a serialized xla::DotDimensionNumbers proto. +precision_config: a serialized xla::PrecisionConfigProto proto. +)doc"); REGISTER_OP("XlaDynamicUpdateSlice") .Input("input: T") @@ -73,6 +159,29 @@ else_branch: A function takes 'inputs' and returns a list of tensors. whose types are the same as what then_branch returns. )doc"); +REGISTER_OP("XlaPad") + .Input("input: T") + .Input("padding_value: T") + .Input("padding_low: Tindices") + .Input("padding_high: Tindices") + .Input("padding_interior: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA Pad operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#pad +. + +input: A `Tensor` of type T. +padding_value: A scalar `Tensor` of type T. +padding_low: the padding to apply at the start of each input dimension. +padding_high: the padding to apply at the end of each input dimension. +padding_interior: the padding to apply between each input element. +output: A `Tensor` of type T. +)doc"); + REGISTER_OP("XlaRecv") .Output("tensor: dtype") .Attr("dtype: type") @@ -98,17 +207,58 @@ tensor_name: A string key that identifies the channel. shape: The shape of the tensor. 
)doc"); +REGISTER_OP("XlaReduce") + .Input("input: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("dimensions_to_reduce: list(int)") + .Attr("reducer: func") + .Output("output: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + if (c->RankKnown(c->input(0))) { + int rank = c->Rank(c->input(0)); + std::vector dimensions_to_reduce; + TF_RETURN_IF_ERROR( + c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce)); + std::set dims_set(dimensions_to_reduce.begin(), + dimensions_to_reduce.end()); + auto dim_in_range = [rank](int64 dim) { + return dim >= 0 && dim < rank; + }; + if (rank < dimensions_to_reduce.size() || + dims_set.size() != dimensions_to_reduce.size() || + !absl::c_all_of(dimensions_to_reduce, dim_in_range)) { + return errors::InvalidArgument( + "Invalid dimensions_to_reduce argument to XlaReduce"); + } + c->set_output( + 0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size())); + } else { + c->set_output(0, c->input(0)); + } + return Status::OK(); + }) + .Doc(R"doc( +Wraps the XLA Reduce operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reduce . 
+ +input: the input tensor +init_value: a scalar representing the initial value for the reduction +reducer: a reducer function to apply +dimensions_to_reduce: dimension numbers over which to reduce +)doc"); + REGISTER_OP("XlaReduceWindow") .Input("input: T") .Input("init_value: T") + .Input("window_dimensions: Tindices") + .Input("window_strides: Tindices") + .Input("padding: Tindices") .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") .Attr("computation: func") - .Attr("window_dimensions: list(int)") - .Attr("window_strides: list(int)") - .Attr("padding_low: list(int)") - .Attr("padding_high: list(int)") .Output("output: T") - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn(UnchangedRank) .Doc(R"doc( Wraps the XLA ReduceWindow operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . @@ -118,8 +268,35 @@ init_value: a scalar representing the initial value for the reduction computation: a reducer function to apply window_dimensions: the shape of the window window_strides: the inter-window strides -padding_low: the padding to apply at the start of each input dimensions -padding_high: the padding to apply at the end of each input dimension. +padding: the padding to apply at the start and end of each input dimensions +)doc"); + +REGISTER_OP("XlaSelectAndScatter") + .Input("operand: T") + .Input("window_dimensions: Tindices") + .Input("window_strides: Tindices") + .Input("padding: Tindices") + .Input("source: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("select: func") + .Attr("scatter: func") + .Output("output: T") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA SelectAndScatter operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter +. 
+ +operand: the input tensor +window_dimensions: the shape of the window +window_strides: the inter-window strides +padding: the padding to apply at the start and end of each input dimensions +source: a tensor of values to scatter +init_value: a scalar representing the initial value for the output tensor +select: a selection function to apply +scatter: a scatter function to apply )doc"); REGISTER_OP("XlaSend") @@ -179,4 +356,5 @@ body: A function that takes a list of tensors and returns another list of tensors. Both lists have the same types as specified by T. )doc"); +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 42b6292f79ffddd155c05758a1420a2a583eb0c6..69ca39436013ec5cf09ba502a1540d5df322e213 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -28,5 +28,6 @@ py_library( srcs = ["xla.py"], deps = [ "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", + "//tensorflow/compiler/xla:xla_data_proto_py", ], ) diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 2fc47dffb8f5f16f24e3beb1ff75aeed3e857c58..3626de375ea9ac12e40ea5b5b591bb6d5262adbc 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -15,11 +15,12 @@ """Experimental library that exposes XLA operations directly in TensorFlow. It is sometimes useful to be able to build HLO programs directly from -TensorFlow. This file provides Tensorflow operators that map as closely as -possible to HLO operators. +TensorFlow. This file provides Tensorflow operators that mirror the semantics of +HLO operators as closely as possible. -There is no promise of backward or forward compatibility for operators defined -in this module. +Note: There is no promise of backward or forward compatibility for operators +defined in this module. 
This is primarily because the underlying HLO operators +do not promise backward or forward compatibility. """ from __future__ import absolute_import @@ -27,11 +28,298 @@ from __future__ import division from __future__ import print_function from tensorflow.compiler.tf2xla.ops import gen_xla_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops + +# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing +# ops include: +# infeed/outfeed (available via tf.contrib.tpu) +# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu) +# conditional +# gather/scatter +# collapse + +# This file reuses builtin names (following XLA's names, so we can call things +# like xla.max), so we capture the builtin versions here. +# pylint: disable=redefined-builtin +_max = max +_min = min +_slice = slice # pylint: disable=invalid-name + +constant = constant_op.constant + +# Unary operators. + +# For most arithmetic operators there is a TensorFlow operator +# that exactly corresponds to each XLA operator. Rather than defining +# XLA-specific variants, we reuse the corresponding TensorFlow operator. +# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1 +# wrap every HLO operator, because that would allow us to be confident that the +# semantics match. + + +def _unary_op(fn): + """Wrapper that restricts `fn` to have the correct signature.""" + + def unary_op_wrapper(x, name=None): + return fn(x, name=name) + + return unary_op_wrapper + + +abs = _unary_op(math_ops.abs) +# TODO(phawkins): implement clz. 
+conj = _unary_op(math_ops.conj) +cos = _unary_op(math_ops.cos) +ceil = _unary_op(math_ops.ceil) +digamma = _unary_op(math_ops.digamma) +erf = _unary_op(math_ops.erf) +erfc = _unary_op(math_ops.erfc) +# TODO(phawkins): implement erfinv +exp = _unary_op(math_ops.exp) +expm1 = _unary_op(math_ops.expm1) +floor = _unary_op(math_ops.floor) +imag = _unary_op(math_ops.imag) +is_finite = _unary_op(math_ops.is_finite) +lgamma = _unary_op(math_ops.lgamma) +log = _unary_op(math_ops.log) +log1p = _unary_op(math_ops.log1p) +logical_not = _unary_op(math_ops.logical_not) +neg = _unary_op(math_ops.neg) +real = _unary_op(math_ops.real) +# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for +# numbers halfway between two integers. +round = _unary_op(math_ops.round) +sin = _unary_op(math_ops.sin) +sign = _unary_op(math_ops.sign) +tanh = _unary_op(math_ops.tanh) + +# Binary operators + +# The main difference between TensorFlow and XLA binary ops is the broadcasting +# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA +# requires an explicit specification of which dimensions to broadcast if the +# arguments have different ranks. + + +def _broadcasting_binary_op(fn): + """Wraps a binary Tensorflow operator and performs XLA-style broadcasting.""" + + def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None): + """Inner wrapper function.""" + broadcast_dims = broadcast_dims or [] + broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64) + # Rather than relying on having static shape information in the TensorFlow + # graph, we use an XlaBroadcastHelper op that can compute the correct shapes + # at JIT compilation time. + x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims) + return fn(x, y, name=name) + + return broadcasting_binary_op_wrapper + + +# Map from TF signed types to TF unsigned types. 
+_SIGNED_TO_UNSIGNED_TABLE = { + dtypes.int8: dtypes.uint8, + dtypes.int16: dtypes.uint16, + dtypes.int32: dtypes.uint32, + dtypes.int64: dtypes.uint64, +} + +# Map from TF unsigned types to TF signed types. +_UNSIGNED_TO_SIGNED_TABLE = { + dtypes.uint8: dtypes.int8, + dtypes.uint16: dtypes.int16, + dtypes.uint32: dtypes.int32, + dtypes.uint64: dtypes.int64, +} + + +def _shift_right_logical_helper(x, y, name=None): + """Performs an integer right logical shift irrespective of input type.""" + assert y.dtype == x.dtype + dtype = x.dtype + signed = dtype in _SIGNED_TO_UNSIGNED_TABLE + if signed: + unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype] + x = math_ops.cast(x, unsigned_dtype) + y = math_ops.cast(y, unsigned_dtype) + output = bitwise_ops.right_shift(x, y, name=name) + if signed: + output = math_ops.cast(output, dtype) + return output + + +def _shift_right_arithmetic_helper(x, y, name=None): + """Performs an integer right arithmetic shift irrespective of input type.""" + assert y.dtype == x.dtype + dtype = x.dtype + unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE + if unsigned: + signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype] + x = math_ops.cast(x, signed_dtype) + y = math_ops.cast(y, signed_dtype) + output = bitwise_ops.right_shift(x, y, name=name) + if unsigned: + output = math_ops.cast(output, dtype) + return output + + +add = _broadcasting_binary_op(math_ops.add) +sub = _broadcasting_binary_op(math_ops.sub) +mul = _broadcasting_binary_op(math_ops.mul) +div = _broadcasting_binary_op(math_ops.div) +rem = _broadcasting_binary_op(gen_math_ops.mod) +max = _broadcasting_binary_op(math_ops.maximum) +min = _broadcasting_binary_op(math_ops.minimum) +atan2 = _broadcasting_binary_op(math_ops.atan2) +complex = _broadcasting_binary_op(math_ops.complex) +logical_and = _broadcasting_binary_op(math_ops.logical_and) +logical_or = _broadcasting_binary_op(math_ops.logical_or) +logical_xor = _broadcasting_binary_op(math_ops.logical_xor) +eq = 
_broadcasting_binary_op(math_ops.equal) +ne = _broadcasting_binary_op(math_ops.not_equal) +ge = _broadcasting_binary_op(math_ops.greater_equal) +gt = _broadcasting_binary_op(math_ops.greater) +le = _broadcasting_binary_op(math_ops.less_equal) +lt = _broadcasting_binary_op(math_ops.less) +pow = _broadcasting_binary_op(math_ops.pow) +shift_left = _broadcasting_binary_op(bitwise_ops.left_shift) +shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper) +shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper) + + +def _binary_op(fn): + """Wrapper that restricts `fn` to have the correct signature.""" + + def binary_op_wrapper(x, y, name=None): + return fn(x, y, name=name) + + return binary_op_wrapper + + +transpose = _binary_op(array_ops.transpose) +rev = _binary_op(array_ops.reverse) + +bitcast_convert_type = array_ops.bitcast + + +def broadcast(x, dims, name=None): + x = ops.convert_to_tensor(x) + shape = array_ops.concat( + [constant_op.constant(dims), + array_ops.shape(x)], axis=0) + return array_ops.broadcast_to(x, shape, name=name) + + +def clamp(a, x, b, name=None): + return min(max(a, x, name=name), b, name=name) + + +concatenate = array_ops.concat + + +def conv(lhs, + rhs, + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers, + feature_group_count=1, + precision_config=None, + name=None): + """Wraps the XLA ConvGeneralDilated operator. + + ConvGeneralDilated is the most general form of XLA convolution and is + documented at + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution + + Args: + lhs: the input tensor + rhs: the kernel tensor + window_strides: the inter-window strides + padding: the padding to apply at the start and end of each input dimensions + lhs_dilation: dilation to apply between input elements + rhs_dilation: dilation to apply between kernel elements + dimension_numbers: a `ConvolutionDimensionNumbers` proto. 
+ feature_group_count: number of feature groups for grouped convolution. + precision_config: a `PrecisionConfigProto` proto. + name: an optional name for the operator + + Returns: + A tensor representing the output of the convolution. + """ + precision_config_proto = "" + if precision_config: + precision_config_proto = precision_config.SerializeToString() + return gen_xla_ops.xla_conv( + lhs, + rhs, + window_strides=window_strides, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + feature_group_count=feature_group_count, + dimension_numbers=dimension_numbers.SerializeToString(), + precision_config=precision_config_proto, + name=name) + + +convert_element_type = math_ops.cast + + +def dot(lhs, rhs, name=None): + return math_ops.tensordot(lhs, rhs, axes=1, name=name) + + +def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): + precision_config_proto = "" + if precision_config: + precision_config_proto = precision_config.SerializeToString() + return gen_xla_ops.xla_dot( + lhs, + rhs, + dimension_numbers=dimension_numbers.SerializeToString(), + precision_config=precision_config_proto, + name=name) + + +def dynamic_slice(x, starts, sizes, name=None): + # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not + # a compile-time constant. This doesn't exactly mimic the semantics of dynamic + # slice if the slice is out of bounds. + return array_ops.slice(x, starts, sizes, name=name) -# TODO(phawkins): provide wrappers for all XLA operators. dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice +# TODO(phawkins): generalize tf.pad to support interior padding, and then remove +# the XLA-specific pad operator. 
+pad = gen_xla_ops.xla_pad + + +def random_normal(mu, sigma, dims, name=None): + mu = ops.convert_to_tensor(mu) + return random_ops.random_normal( + dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name) + + +def random_uniform(minval, maxval, dims, name=None): + minval = ops.convert_to_tensor(minval) + return random_ops.random_uniform( + dims, minval, maxval, dtype=minval.dtype, name=name) + + +recv = gen_xla_ops.xla_recv +reduce = gen_xla_ops.xla_reduce + def reduce_window(operand, init, @@ -61,22 +349,38 @@ def reduce_window(operand, """ window_strides = window_strides or [1] * len(window_dimensions) padding = padding or [(0, 0)] * len(window_dimensions) - padding_low = [x for (x, _) in padding] - padding_high = [y for (_, y) in padding] return gen_xla_ops.xla_reduce_window( - operand, - init, - reducer, - window_dimensions, - window_strides, - padding_low, - padding_high, + input=operand, + init_value=init, + window_dimensions=window_dimensions, + window_strides=window_strides, + padding=padding, + computation=reducer, name=name) -recv = gen_xla_ops.xla_recv +def reshape(x, new_sizes, dimensions=None, name=None): + if dimensions is not None: + x = array_ops.transpose(x, dimensions) + x = array_ops.reshape(x, new_sizes, name=name) + return x + + +def select(condition, x, y, name=None): + return array_ops.where(condition, x, y, name) + + +select_and_scatter = gen_xla_ops.xla_select_and_scatter send = gen_xla_ops.xla_send -sort = gen_xla_ops.xla_sort +def slice(x, start_dims, limit_dims, strides): + spec = [ + _slice(start, limit, stride) + for (start, limit, stride) in zip(start_dims, limit_dims, strides) + ] + return x[tuple(spec)] + + +sort = gen_xla_ops.xla_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc new file mode 100644 index 0000000000000000000000000000000000000000..32ba6df2e6daa2add468a1bc0559d42606d1a9a6 --- /dev/null +++ 
b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "absl/algorithm/container.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { +/*static*/ StringPiece XlaResourceOpInfo::XlaResourceOpKindToString( + XlaResourceOpKind op_kind) { + switch (op_kind) { + case XlaResourceOpKind::kRead: + return "Read"; + case XlaResourceOpKind::kWrite: + return "Write"; + case XlaResourceOpKind::kReadWrite: + return "Modify"; + } +} + +static gtl::FlatMap* CreateResourceOpInfoMap() { + gtl::FlatMap* result = + new gtl::FlatMap; + + auto add = [&](StringPiece op, XlaResourceOpKind op_kind, + XlaResourceKind resource_kind) { + auto insert_result = + result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)}); + CHECK(insert_result.second); + }; + + auto kRead = XlaResourceOpKind::kRead; + auto kWrite = XlaResourceOpKind::kWrite; + auto kReadWrite = XlaResourceOpKind::kReadWrite; + + auto kVariable = XlaResourceKind::kVariable; + auto kStack = XlaResourceKind::kStack; + auto kTensorArray = XlaResourceKind::kTensorArray; + + // clang-format off + add("AssignAddVariableOp" , kReadWrite, kVariable); + add("AssignSubVariableOp" , kReadWrite, kVariable); + add("AssignVariableOp" , kWrite, 
kVariable); + add("ReadVariableOp" , kRead, kVariable); + add("ResourceApplyAdaMax" , kReadWrite, kVariable); + add("ResourceApplyAdadelta" , kReadWrite, kVariable); + add("ResourceApplyAdagrad" , kReadWrite, kVariable); + add("ResourceApplyAdagradDA" , kReadWrite, kVariable); + add("ResourceApplyAdam" , kReadWrite, kVariable); + add("ResourceApplyAddSign" , kReadWrite, kVariable); + add("ResourceApplyCenteredRMSProp" , kReadWrite, kVariable); + add("ResourceApplyFtrl" , kReadWrite, kVariable); + add("ResourceApplyFtrlV2" , kReadWrite, kVariable); + add("ResourceApplyGradientDescent" , kReadWrite, kVariable); + add("ResourceApplyMomentum" , kReadWrite, kVariable); + add("ResourceApplyPowerSign" , kReadWrite, kVariable); + add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable); + add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable); + add("ResourceApplyRMSProp" , kReadWrite, kVariable); + add("ResourceGather" , kRead, kVariable); + add("ResourceScatterAdd" , kReadWrite, kVariable); + add("ResourceScatterDiv" , kReadWrite, kVariable); + add("ResourceScatterMax" , kReadWrite, kVariable); + add("ResourceScatterMin" , kReadWrite, kVariable); + add("ResourceScatterMul" , kReadWrite, kVariable); + add("ResourceScatterNdAdd" , kReadWrite, kVariable); + add("ResourceScatterNdUpdate" , kReadWrite, kVariable); + add("ResourceScatterSub" , kReadWrite, kVariable); + add("ResourceScatterUpdate" , kReadWrite, kVariable); + add("ResourceStridedSliceAssign" , kReadWrite, kVariable); + add("VarIsInitializedOp" , kRead, kVariable); + add("VariableShape" , kRead, kVariable); + + add("StackV2" , kWrite, kStack); + add("StackCloseV2" , kRead, kStack); + add("StackPopV2" , kReadWrite, kStack); + add("StackPushV2" , kReadWrite, kStack); + + add("TensorArrayV3" , kWrite, kTensorArray); + add("TensorArrayConcatV3" , kRead, kTensorArray); + add("TensorArrayGatherV3" , kRead, kTensorArray); + add("TensorArrayScatterV3" , kWrite, kTensorArray); + add("TensorArrayGradV3" , 
kRead, kTensorArray); + add("TensorArrayCloseV3" , kRead, kTensorArray); + add("TensorArrayReadV3" , kRead, kTensorArray); + add("TensorArraySizeV3" , kRead, kTensorArray); + add("TensorArraySplitV3" , kWrite, kTensorArray); + add("TensorArrayWriteV3" , kWrite, kTensorArray); + // clang-format on + + return result; +} + +static const gtl::FlatMap& +GetStaticResourceOpInfoMap() { + static gtl::FlatMap* op_info_map = + CreateResourceOpInfoMap(); + return *op_info_map; +} + +const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op) { + const gtl::FlatMap& op_infos = + GetStaticResourceOpInfoMap(); + auto it = op_infos.find(op); + return it == op_infos.end() ? nullptr : &it->second; +} + +namespace resource_op_table_internal { +std::vector GetKnownResourceOps() { + std::vector result; + for (const auto& p : GetStaticResourceOpInfoMap()) { + result.push_back(p.first); + } + absl::c_sort(result); + return result; +} +} // namespace resource_op_table_internal +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.h b/tensorflow/compiler/tf2xla/resource_operation_table.h new file mode 100644 index 0000000000000000000000000000000000000000..7f627a64c6e8298a427cd87d25d4ba24835bf542 --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table.h @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ + +#include +#include + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" + +// Exposes information about the resource operations supported by tf2xla in a +// structured form. + +namespace tensorflow { +enum class XlaResourceOpKind { + kRead, // Only reads from resources. + kWrite, // Only writes to resources. + kReadWrite // Reads from and writes to resources. +}; + +enum class XlaResourceKind { + kVariable, // Operates on resource variables. + kStack, // Operates on stacks. + kTensorArray // Operates on tensor arrays. +}; + +class XlaResourceOpInfo { + public: + explicit XlaResourceOpInfo(XlaResourceOpKind op_kind, + XlaResourceKind resource_kind) + : op_kind_(op_kind), resource_kind_(resource_kind) {} + + XlaResourceOpKind kind() const { return op_kind_; } + XlaResourceKind resource_kind() const { return resource_kind_; } + + static StringPiece XlaResourceOpKindToString(XlaResourceOpKind op_kind); + + private: + XlaResourceOpKind op_kind_; + XlaResourceKind resource_kind_; +}; + +// Returns a XlaResourceOpInfo describing `op` if it is a resource operation +// supported by tf2xla, otherwise returns null (i.e. if this returns null then +// `op` is either not a resource operation or is unsupported by XLA). +const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op); + +namespace resource_op_table_internal { +// NB! Implementation detail exposed for unit testing, do not use. +// +// Returns the set of resource operations known by this module. 
+std::vector GetKnownResourceOps(); +} // namespace resource_op_table_internal + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0343f80de9fed114a0097b981233277c3e12b378 --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +bool IsResourceArgDef(const OpDef::ArgDef& arg_def) { + return arg_def.type() == DT_RESOURCE; +} + +bool HasResourceInputOrOutput(const OpDef& op_def) { + return absl::c_any_of(op_def.input_arg(), IsResourceArgDef) || + absl::c_any_of(op_def.output_arg(), IsResourceArgDef); +} + +TEST(ResourceOperationTableTest, HaveAllResourceOps) { + gtl::FlatMap known_resource_ops; + for (StringPiece known_resource_op : + resource_op_table_internal::GetKnownResourceOps()) { + ASSERT_TRUE( + known_resource_ops.insert({string(known_resource_op), false}).second); + } + + std::vector xla_op_names = XlaOpRegistry::GetAllRegisteredOps(); + for (const string& xla_op_name : xla_op_names) { + const OpDef* op_def; + TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef(xla_op_name, &op_def)); + if (HasResourceInputOrOutput(*op_def)) { + EXPECT_EQ(known_resource_ops.count(xla_op_name), 1) + << "Unknown resource op " << xla_op_name; + known_resource_ops[xla_op_name] = true; + } + } + + std::vector unnecessary_resource_ops; + for (const auto& pair : known_resource_ops) { + if (!pair.second) { + unnecessary_resource_ops.push_back(pair.first); + } + } + + EXPECT_TRUE(unnecessary_resource_ops.empty()) + << "Stale resource ops:\n" + << absl::StrJoin(unnecessary_resource_ops, "\n"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 5759c72af301785f3ca1110b58eeb2fe7dead713..2d7eb8b915b8245ba6573c30b2eb15b12fc3a1b4 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ 
b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "absl/strings/match.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" @@ -27,10 +27,10 @@ const char kShardingAttribute[] = "_XlaSharding"; } // namespace namespace { -xla::StatusOr> -GetShardingFromNodeDef(const NodeDef& node_def) { +xla::StatusOr> GetShardingFromNodeDef( + const NodeDef& node_def) { if (!HasNodeAttr(node_def, kShardingAttribute)) { - return tensorflow::gtl::optional(); + return absl::optional(); } string value; xla::OpSharding sharding; @@ -40,7 +40,7 @@ GetShardingFromNodeDef(const NodeDef& node_def) { "Experimental _XlaSharding attribute was not a valid encoded " "xla::OpSharding proto."); } - return tensorflow::gtl::optional(sharding); + return absl::optional(sharding); } Status CoreOutOfRangeError(int core, int num_cores_per_replica) { @@ -50,12 +50,11 @@ Status CoreOutOfRangeError(int core, int num_cores_per_replica) { } } // namespace -xla::StatusOr> -ParseShardingFromDevice( +xla::StatusOr> ParseShardingFromDevice( const string& device_name, int num_cores_per_replica, - tensorflow::gtl::optional explicit_sharding) { + absl::optional explicit_sharding) { if (device_name.empty()) { - return tensorflow::gtl::optional(); + return absl::optional(); } DeviceNameUtils::ParsedName parsed_device; if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) { @@ -66,34 +65,34 @@ ParseShardingFromDevice( if (explicit_sharding.has_value()) { return explicit_sharding; } else if (!parsed_device.has_type || !parsed_device.has_id || - !str_util::StrContains(parsed_device.type, - kDeviceSuffixReplicatedCore)) { - return 
tensorflow::gtl::optional(); + !absl::StrContains(parsed_device.type, + kDeviceSuffixReplicatedCore)) { + return absl::optional(); } else { const int core = parsed_device.id; if (core < 0 || core >= num_cores_per_replica) { return CoreOutOfRangeError(core, num_cores_per_replica); } - return tensorflow::gtl::optional( + return absl::optional( xla::sharding_builder::AssignDevice(core)); } } -xla::StatusOr> -ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica) { +xla::StatusOr> ParseShardingFromDevice( + const NodeDef& node_def, int num_cores_per_replica) { const string& device_name = node_def.device(); - TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + TF_ASSIGN_OR_RETURN(absl::optional sharding, GetShardingFromNodeDef(node_def)); return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); } -xla::StatusOr> -ParseShardingFromDevice(const Node& node, int num_cores_per_replica) { +xla::StatusOr> ParseShardingFromDevice( + const Node& node, int num_cores_per_replica) { string device_name = node.assigned_device_name(); if (device_name.empty()) { device_name = node.requested_device(); } - TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + TF_ASSIGN_OR_RETURN(absl::optional sharding, GetShardingFromNodeDef(node.def())); return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); } diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index b1c817bdcc211648b16e395313ca171d1acb9ea9..ab67d4f154282e3fc37b68339045deb5da91b9db 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_ -#define TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_ #include @@ -33,19 +33,18 @@ namespace tensorflow { // - explicit_sharding if explicit_sharding.has_value() // - a non-value if there is no assigned core or // - a sharding set as per xla::sharding_builder::AssignDevice. -xla::StatusOr> -ParseShardingFromDevice(const string& device_name, int num_cores_per_replica, - tensorflow::gtl::optional - explicit_sharding = tensorflow::gtl::nullopt); +xla::StatusOr> ParseShardingFromDevice( + const string& device_name, int num_cores_per_replica, + absl::optional explicit_sharding = absl::nullopt); -xla::StatusOr> -ParseShardingFromDevice(const Node& node, int num_cores_per_replica); +xla::StatusOr> ParseShardingFromDevice( + const Node& node, int num_cores_per_replica); -xla::StatusOr> -ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica); +xla::StatusOr> ParseShardingFromDevice( + const NodeDef& node_def, int num_cores_per_replica); void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_ +#endif // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/sharding_util_test.cc b/tensorflow/compiler/tf2xla/sharding_util_test.cc index bff5978237a827cb9650541f2cf6984d9e846796..dcb7e212b74d2e261de7e125bb66b3ec78e0cfe9 100644 --- a/tensorflow/compiler/tf2xla/sharding_util_test.cc +++ b/tensorflow/compiler/tf2xla/sharding_util_test.cc @@ -23,7 +23,7 @@ TEST(CoreUtilTest, ParseShardingFromDevice) { Graph graph(OpRegistry::Global()); auto core_from_sharding = - [](tensorflow::gtl::optional sharding) -> int64 { + [](absl::optional sharding) -> int64 { if (sharding.has_value() && sharding.value().type() == 
xla::OpSharding::Type::OpSharding_Type_MAXIMAL) { diff --git a/tensorflow/compiler/tf2xla/str_util.cc b/tensorflow/compiler/tf2xla/str_util.cc deleted file mode 100644 index 2b0834fe7b6c4d2199267dbe0ec1f7c2785aa9c7..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/str_util.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/str_util.h" - -#include -#include -#include - -namespace tensorflow { -namespace str_util { - -static void ReplaceAll(string* text, StringPiece from, StringPiece to) { - size_t pos = 0; - while ((pos = text->find(from.data(), pos, from.size())) != string::npos) { - text->replace(pos, from.size(), to.data(), to.size()); - pos += to.size(); - if (from.empty()) { - pos++; // Match at the beginning of the text and after every byte - } - } -} - -void ReplaceAllPairs(string* text, - const std::vector>& replace) { - for (const std::pair& from_to : replace) { - ReplaceAll(text, from_to.first, from_to.second); - } -} - -} // namespace str_util -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/str_util.h b/tensorflow/compiler/tf2xla/str_util.h deleted file mode 100644 index 51f25009d7003db0d72296619a469ecbbbb1808d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/str_util.h +++ /dev/null @@ -1,42 +0,0 
@@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// String utilities that are esoteric enough that they don't belong in -// third_party/tensorflow/core/lib/strings/str_util.h, but are still generally -// useful under xla. - -#ifndef TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_ -#define TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_ - -#include -#include -#include - -#include "tensorflow/core/lib/core/stringpiece.h" - -namespace tensorflow { -namespace str_util { - -// Replace all non-overlapping occurrences of the given (from,to) pairs in-place -// in text. If from is empty, it matches at the beginning of the text and after -// every byte. Each (from,to) replacement pair is processed in the order it is -// given. -void ReplaceAllPairs(string* text, - const std::vector>& replace); - -} // namespace str_util -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/str_util_test.cc b/tensorflow/compiler/tf2xla/str_util_test.cc deleted file mode 100644 index 8817f6902a8e58e796ca5240a9a24d7506d38793..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/str_util_test.cc +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/str_util.h" - -#include -#include -#include - -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace str_util { - -class ReplaceAllPairsTest : public ::testing::Test { - protected: - void ExpectReplaceAllPairs( - string text, const std::vector>& replace, - StringPiece want) { - ReplaceAllPairs(&text, replace); - EXPECT_EQ(text, want); - } -}; - -TEST_F(ReplaceAllPairsTest, Simple) { - ExpectReplaceAllPairs("", {}, ""); - ExpectReplaceAllPairs("", {{"", ""}}, ""); - ExpectReplaceAllPairs("", {{"", "X"}}, "X"); - ExpectReplaceAllPairs("", {{"", "XYZ"}}, "XYZ"); - ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}}, "_X_Y_Z_"); - ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}, {"_Y_", "a"}}, "_XaZ_"); - ExpectReplaceAllPairs("banana", {}, "banana"); - ExpectReplaceAllPairs("banana", {{"", ""}}, "banana"); - ExpectReplaceAllPairs("banana", {{"", "_"}}, "_b_a_n_a_n_a_"); - ExpectReplaceAllPairs("banana", {{"", "__"}}, "__b__a__n__a__n__a__"); - ExpectReplaceAllPairs("banana", {{"a", "a"}}, "banana"); - ExpectReplaceAllPairs("banana", {{"a", ""}}, "bnn"); - ExpectReplaceAllPairs("banana", {{"a", "X"}}, "bXnXnX"); - ExpectReplaceAllPairs("banana", {{"a", "XX"}}, "bXXnXXnXX"); - ExpectReplaceAllPairs("banana", {{"a", "XX"}, {"XnX", "z"}}, "bXzzX"); - 
ExpectReplaceAllPairs("a{{foo}}b{{bar}}c{{foo}}", - {{"{{foo}}", "0"}, {"{{bar}}", "123456789"}}, - "a0b123456789c0"); -} - -} // namespace str_util -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 48568c825b7a0f13011d3d6e8e62ec5db026760f..f34af2d67debe8bfa4abcad19e42c55ea40c4e82 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -40,7 +41,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -197,8 +197,8 @@ Status RewriteAndPruneGraph( if (!missing_feeds.empty() || !missing_fetches.empty()) { return errors::Aborted( "Post graph-pruning", - ", missing feeds: ", str_util::Join(missing_feeds, ", "), - ", missing fetches: ", str_util::Join(missing_fetches, ", ")); + ", missing feeds: ", absl::StrJoin(missing_feeds, ", "), + ", missing fetches: ", absl::StrJoin(missing_fetches, ", ")); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc index 7aca889a266439538c4cd1c153460e6cc871b246..567d212b5eee493d29a1817987cbd7759575386e 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc @@ -20,11 +20,11 @@ limitations under the License. 
#include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -54,10 +54,10 @@ void PrintSupportedOps(const string& device, const string& regen_run) { } std::sort(types.begin(), types.end()); constraints.push_back("`" + constraint.name() + "={" + - str_util::Join(types, ",") + "}`"); + absl::StrJoin(types, ",") + "}`"); } std::cout << "`" << kdef->op() << "` | " - << str_util::Join(constraints, "
") << std::endl; + << absl::StrJoin(constraints, "
") << std::endl; } std::cout << "\nTo regenerate this table, run:\n\n```shell\n" @@ -76,7 +76,7 @@ void SupportedOpsMain(int argc, char** argv, const char* regen_run) { {"device", &device, "Name of the compilation device for which to print supported ops, " "one of: " + - str_util::Join(device_names, ",")}, + absl::StrJoin(device_names, ",")}, }; string usage = Flags::Usage(argv[0], flag_list); bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 0e07485d1861aa40b14e527b14947c6f8bab647e..e284e0b191ac09f9491973166c80b731c8ea51a5 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -233,7 +233,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, // Push input nodes of the currently visited node to name_queue. for (const string& in_edge : map_entry.second->input()) { auto id = ParseTensorName(in_edge); - const string node_name = std::string(id.first); + const string node_name = string(id.first); if (feed_tensors.find(std::make_pair(node_name, id.second)) == feed_tensors.end()) { name_queue.push(node_name); @@ -268,7 +268,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { if (edge->IsControlEdge()) continue; const Node* possible_match = out_edges ? 
edge->dst() : edge->src(); TF_ASSIGN_OR_RETURN( - tensorflow::gtl::optional sharding, + absl::optional sharding, ParseShardingFromDevice( *possible_match, /*num_cores_per_replica=*/std::numeric_limits::max())); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index ae51446204baf14dc03fc6305641048dbf3872b0..2b1f724dc7b2e2bb6d06115827f92bf0670955b3 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" @@ -25,16 +26,15 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { -void ExpectErrorContains(const Status& status, StringPiece str) { +void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) + EXPECT_TRUE(absl::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index e89f4733281194f0263ae8cc4907caa0ad781165..d98237bd5c9288e6337e10c19c2d7574ad2e4c97 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -103,7 +103,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, auto sharding_parse_result = 
ParseShardingFromDevice( op_kernel->def(), std::numeric_limits::max()); OP_REQUIRES_OK(context, sharding_parse_result.status()); - tensorflow::gtl::optional op_sharding = + absl::optional op_sharding = sharding_parse_result.ValueOrDie(); // If no sharding metadata is found, XLA is free to use whatever device it diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 226c89bcf1e66b5afb43cddb03db39b931ca55a8..0c300c282e9698534af6372b2f2ddae06f88db24 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -310,7 +311,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, // unique_ptr so we can capture the cleanup status in the end. xla_context->Ref(); Status status; - auto step_container = xla::MakeUnique( + auto step_container = absl::make_unique( step_id, [&status, device](const string& name) { status = device->resource_manager()->Cleanup(name); }); @@ -360,6 +361,9 @@ Status BuildComputation( if (retval.has_constant_value()) { output.is_constant = true; output.constant_value = retval.constant_value(); + } else if (retval.resource() != nullptr) { + output.is_constant = false; + output.input_index = retval.resource()->arg_num(); } else { output.is_constant = false; elems.push_back(retval.handle()); @@ -413,7 +417,7 @@ Status BuildComputation( // Request that the value be returned on a specific core. xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? tensorflow::gtl::optional() + builder, core == -1 ? 
absl::optional() : xla::sharding_builder::AssignDevice(core)); xla::XlaOp handle; @@ -464,8 +468,6 @@ Status XlaCompiler::BuildArguments( // XLA computation as runtime parameters. input_mapping->clear(); input_mapping->reserve(args.size()); - std::vector resources; - resources.reserve(args.size()); // Fills in constant arguments, and computes non-constant argument order. for (std::vector::size_type i = 0; i < args.size(); @@ -484,8 +486,9 @@ Status XlaCompiler::BuildArguments( /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); arg_expression.set_resource(resource); if (arg.initialized) { - resources.push_back(i); + input_mapping->push_back(i); } + break; case XlaCompiler::Argument::kParameter: { input_mapping->push_back(i); @@ -495,14 +498,11 @@ Status XlaCompiler::BuildArguments( arg_expression.set_constant_value(arg.constant_value); break; case XlaCompiler::Argument::kInvalid: - return errors::Internal("Unreachable case in BuildArguments()"); + return errors::Internal( + "Unreachable case in BuildArguments() while filling constant args"); } } - // Append parameters containing variable values after the other runtime - // parameters. - input_mapping->insert(input_mapping->end(), resources.begin(), - resources.end()); if (input_mapping->empty()) { return Status::OK(); } @@ -570,7 +570,7 @@ Status XlaCompiler::BuildArguments( for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? tensorflow::gtl::optional() + builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = xla::GetTupleElement(tuple, i); } @@ -578,7 +578,7 @@ Status XlaCompiler::BuildArguments( for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? 
tensorflow::gtl::optional() + builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], strings::StrCat("arg", i)); @@ -619,7 +619,8 @@ Status XlaCompiler::BuildArguments( break; case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: - return errors::Internal("Unreachable case in BuildArguments()"); + return errors::Internal( + "Unreachable case in BuildArguments() while filling handles"); } } @@ -791,14 +792,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); - // Copy the host transfer metadata to the result. - for (const auto& send : host_compute_sends_) { - *result->host_compute_metadata.add_device_to_host() = send.second; - } - for (const auto& recv : host_compute_recvs_) { - *result->host_compute_metadata.add_host_to_device() = recv.second; - } - // Tensorflow expects a major-to-minor order of results. 
xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); @@ -816,10 +809,34 @@ Status XlaCompiler::GetChannelHandle(const string& key, return Status::OK(); } +Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key, + xla::ChannelHandle* channel) { + auto result = channels_.emplace(key, xla::ChannelHandle()); + if (result.second) { + TF_ASSIGN_OR_RETURN(result.first->second, + client()->CreateHostToDeviceChannelHandle()); + } + *channel = result.first->second; + VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString(); + return Status::OK(); +} + +Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key, + xla::ChannelHandle* channel) { + auto result = channels_.emplace(key, xla::ChannelHandle()); + if (result.second) { + TF_ASSIGN_OR_RETURN(result.first->second, + client()->CreateDeviceToHostChannelHandle()); + } + *channel = result.first->second; + VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString(); + return Status::OK(); +} + namespace { -void SetTransfer(const string& key, gtl::ArraySlice types, - gtl::ArraySlice shapes, +void SetTransfer(const string& key, absl::Span types, + absl::Span shapes, tf2xla::HostTransferMetadata* transfer) { transfer->set_key(key); CHECK(types.size() == shapes.size()); @@ -833,8 +850,8 @@ void SetTransfer(const string& key, gtl::ArraySlice types, } // namespace Status XlaCompiler::SetDeviceToHostMetadata( - const string& key, gtl::ArraySlice types, - gtl::ArraySlice shapes) { + const string& key, absl::Span types, + absl::Span shapes) { if (host_compute_sends_.find(key) != host_compute_sends_.end()) { return errors::InvalidArgument( "Duplicate calls to SetDeviceToHostMetadata with key ", key); @@ -860,8 +877,8 @@ Status XlaCompiler::GetDeviceToHostShapes( } Status XlaCompiler::SetHostToDeviceMetadata( - const string& key, gtl::ArraySlice types, - gtl::ArraySlice shapes) { + const string& key, absl::Span types, + absl::Span shapes) { if 
(host_compute_recvs_.find(key) != host_compute_sends_.end()) { return errors::InvalidArgument( "Duplicate calls to SetHostToDeviceMetadata with key ", key); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 25332c8d8e3210a0217a1ba3f5767115fe6b1d93..8f4a9858ed63403b9d0f967b61d3f690f12df21a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -183,6 +183,8 @@ class XlaCompiler { struct OutputDescription { // Type and shape of the output. The shape is the unflattened shape. + // When `type` is DT_RESOURCE, `shape` is the shape of the resource + // variable's value. DataType type; TensorShape shape; @@ -190,6 +192,10 @@ class XlaCompiler { // 'Tensor' is in host memory. bool is_constant = false; Tensor constant_value; + + // When this output is a resource, i.e. `type == DT_RESOURCE`, this is + // the index of the input that contains the resource. + int input_index; }; // Describes a variable write side effect of the computation. @@ -212,9 +218,9 @@ class XlaCompiler { struct CompilationResult { // Vector that maps from the parameters of the XLA computation to their - // original argument positions. To handle compile-time constant inputs and - // resources, the parameters to the XLA computation may be a subset of the - // original arguments, and are not necessarily in the same order.) + // original argument positions. To handle compile-time constant inputs, the + // parameters to the XLA computation may be a subset of the original + // arguments. The relative ordering of parameters are maintained. std::vector input_mapping; // Input shapes of the computation. If we are flattening inputs, these are @@ -332,11 +338,21 @@ class XlaCompiler { // same XlaCompiler. Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); + // Retrieves the host-to-device channel handle associated with `key`. + // Allocates a new channel handle if none exists. 
+ Status GetHostToDeviceChannelHandle(const string& key, + xla::ChannelHandle* channel); + + // Retrieves the device-to-host channel handle associated with `key`. + // Allocates a new channel handle if none exists. + Status GetDeviceToHostChannelHandle(const string& key, + xla::ChannelHandle* channel); + // Sets the shapes and types for the device to host transfer associated with // 'key'. Status SetDeviceToHostMetadata(const string& key, - gtl::ArraySlice types, - gtl::ArraySlice shapes); + absl::Span types, + absl::Span shapes); // Gets the shapes the device to host transfer associated with 'key'. Status GetDeviceToHostShapes(const string& key, @@ -345,8 +361,8 @@ class XlaCompiler { // Sets the shapes and types for the host to device transfer associated with // 'key'. Status SetHostToDeviceMetadata(const string& key, - gtl::ArraySlice types, - gtl::ArraySlice shapes); + absl::Span types, + absl::Span shapes); // In order to avoid deadlocks from dependencies in host computations, it can // be necessary to enforce a partial order on the execution of HostCompute diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index be00ed8813fdf2778d6af81556001ef51538dd34..be3c93ae47bf16a67ed4fac34a99997cc7888559 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" @@ -38,7 +39,6 @@ limitations under the License. 
#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/version.h" @@ -280,6 +280,54 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); } +// Tests that the compiler doesn't reorder the parameters. +TEST_F(XlaCompilerTest, MixedOrderArguments) { + for (bool swap_order : {false, true}) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto var = + ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, swap_order ? 0 : 1); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, swap_order ? 1 : 0); + // Adds an identity op around the resource to make sure identity ops + // propagate resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + if (swap_order) { + // Even after swapping arguments, the compiler should maintain the new + // ordering of parameters. + std::swap(args[0], args[1]); + } + // Compiles the graph. 
+ XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompileOptions compile_options; + compile_options.always_return_tuple = false; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1)); + } +} + TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { // Builds a graph that adds reshapes a tensor, but with the shape not // statically known. @@ -309,10 +357,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { std::move(graph), args, &result); EXPECT_FALSE(status.ok()); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "depends on a parameter")) + absl::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape")) + absl::StrContains(status.error_message(), "[[{{node C}} = Reshape")) << status.error_message(); } @@ -727,8 +775,7 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) { compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, /*args=*/{}, &result); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "is not defined.")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined.")) << status.error_message(); } @@ -807,21 +854,49 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { ASSERT_FALSE(status.ok()); // Flib lookup failure. - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "is not defined.")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined.")) << status.error_message(); // Local flib lookup failure. 
- EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "Attr T is not found")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found")) << status.error_message(); } +void RunAndCheckVariablesComputation( + xla::Client* client, const XlaCompiler::CompilationResult& result) { + std::unique_ptr param0_literal = + xla::LiteralUtil::CreateR1({7, 42}); + std::unique_ptr param1_literal = + xla::LiteralUtil::CreateR1({-3, 101}); + std::unique_ptr param0_data = + client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr expected0 = + xla::LiteralUtil::CreateR1({5, 144}); + std::unique_ptr expected1 = + xla::LiteralUtil::CreateR1({4, 143}); + std::unique_ptr expected_literal = + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); +} + // Tests a simple graph that reads and writes a variable. TEST_F(XlaCompilerTest, Variables) { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); - auto write = ops::AssignAddVariableOp(scope, var, a); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); auto read = ops::ReadVariableOp( scope.WithControlDependencies(std::vector{write}), var, DT_INT32); @@ -844,36 +919,90 @@ TEST_F(XlaCompilerTest, Variables) { // Compiles the graph. 
XlaCompiler compiler(DefaultOptions()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + RunAndCheckVariablesComputation(client_, result); +} + +// Tests a simple graph that reads and writes a variable. +TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0); + auto d = ops::_Retval(scope.WithOpName("D"), var, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kVariable; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", std::move(graph), args, &result)); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = - client_ - ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + client_->Execute(*result.computation, {param1_data.get()}) .ConsumeValueOrDie(); std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({5, 144}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({4, 143}); std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } +TEST_F(XlaCompilerTest, ReturnResourceHandle) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto r = ops::_Retval(scope.WithOpName("R"), var, 0); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 1); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. 
+ std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + RunAndCheckVariablesComputation(client_, result); +} + xla::StatusOr> BuildTestGraph() { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); @@ -1075,9 +1204,9 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp")) << status.error_message(); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}")) << status.error_message(); } @@ -1100,10 +1229,10 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "is not in the list of allowed values")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "is not in the list of allowed values")) << status.error_message(); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Shape}}")) << status.error_message(); } @@ -1127,9 +1256,9 @@ TEST_F(XlaCompilerTest, 
SingleOpWithoutInputs) { std::move(graph_copy), args, &result); ASSERT_FALSE(status.ok()); EXPECT_TRUE( - str_util::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: {{node NoOp}}")) + absl::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: {{node NoOp}}")) << status.error_message(); } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index b24e3aabbe6ba858a8bfb4dd435726984cc7b0f5..24a4b92b45a3f3563e435fa074fce595d6c0b263 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -107,6 +107,19 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, return Status::OK(); } +Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { + VLOG(1) << "Adding retval index " << retval_index << " with resource " + << resource->name() << ":" << resource->shape().DebugString() + << " to XLA computation"; + if (retvals_.size() <= retval_index) { + retvals_.resize(retval_index + 1); + } + XlaExpression e; + e.set_resource(resource); + retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e}; + return Status::OK(); +} + xla::XlaBuilder* XlaContext::builder() { return builder_; } Status XlaContext::CreateResource( diff --git 
a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 3db37afdba71342cfb20af8841a40cb54709ca73..4da891634e97dd67af0ef09ef33dbc7a4d19743b 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -86,6 +86,9 @@ class XlaContext : public ResourceBase { Status AddConstRetval(int retval_index, DataType dtype, const xla::LiteralSlice& literal); + // As for Retval, but for return values that are resource handles. + Status AddResourceRetval(int retval_index, XlaResource* resource); + // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` // constructor for a description of the remaining arguments. diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 8efb3d55c88757b9366bdf9622287bdd0a72e295..9a34cd8c6ae2dc6d52a3cc69168df96f5322c6da 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -31,7 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -119,7 +119,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, } /* static */ Status XlaHelpers::ReshapeLiteral( - const xla::Literal& input, gtl::ArraySlice dimensions, + const xla::Literal& input, absl::Span dimensions, xla::Literal* output) { if (xla::ShapeUtil::IsTuple(input.shape())) { return errors::InvalidArgument("ReshapeLiteral does not support tuples."); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index e6522157a535fc3e4ec96cb0496b6be2e525c336..39578144caaadf293d24ea91aa874e56e27ecc01 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -18,10 +18,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -50,7 +50,7 @@ class XlaHelpers { // Reshapes literal 'input' to have 'shape'. Both the original shape and // 'shape' must contain the same number of elements. static Status ReshapeLiteral(const xla::Literal& input, - gtl::ArraySlice shape, + absl::Span shape, xla::Literal* output); // Returns the argmax of `input` along `axis`. 
`output_type` is the type to diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 82028c8b9ca9f65a73f8b50edc0a47c7068aba9a..1499c99ed15eceaf6bfa2ef0dd1d5885b1e5fc58 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -99,8 +99,27 @@ Status XlaOpKernelContext::ConstantInput(int index, index, context_->input(index).shape().dim_sizes(), constant_literal); } +static xla::StatusOr InputIndex(XlaOpKernelContext* context, + StringPiece name) { + int start, stop; + TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was " + "expected"); + } + return start; +} + +Status XlaOpKernelContext::ConstantInput(StringPiece name, + xla::Literal* constant_literal) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInput(index, constant_literal); +} + Status XlaOpKernelContext::ConstantInputReshaped( - int index, gtl::ArraySlice new_dims, + int index, absl::Span new_dims, xla::Literal* constant_literal) { const Tensor& tensor = context_->input(index); TensorShape new_shape(new_dims); @@ -246,6 +265,12 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) { return LiteralToInt64Scalar(literal, out); } +Status XlaOpKernelContext::ConstantInputAsIntScalar(StringPiece name, + int64* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsIntScalar(index, out); +} + Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); @@ -280,6 +305,20 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index, return LiteralToInt64Vector(literal, out); } +Status XlaOpKernelContext::ConstantInputAsIntVector(StringPiece name, + std::vector* out) { + 
TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsIntVector(index, out); +} + +Status XlaOpKernelContext::ConstantInputReshapedToIntVector( + int index, std::vector* out) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInputReshaped( + index, {InputShape(index).num_elements()}, &literal)); + return LiteralToInt64Vector(literal, out); +} + Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, xla::Literal* out) { xla::Literal literal; @@ -305,6 +344,12 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, } } +Status XlaOpKernelContext::ConstantInputAsInt64Literal(StringPiece name, + xla::Literal* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsInt64Literal(index, out); +} + // TODO(phawkins): validate that the dimensions form a valid shape, fail // gracefully if they do not. Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index ac9dfe3369078df7392a4ef04679f7d7beacf8bb..45cfa7da740c38afde0158568a019a4426992b64 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -106,26 +106,34 @@ class XlaOpKernelContext { // expression cannot be evaluated, e.g., because it depends on unbound // parameters, returns a non-OK status. Status ConstantInput(int index, xla::Literal* constant_literal); + Status ConstantInput(StringPiece name, xla::Literal* constant_literal); // Evaluates input `index`, reshapes it to `new_shape` if new_shape != // InputShape(index), and stores it in `*constant_literal`. If the input // cannot be evaluated, e.g., because it depends on unbound parameters, // returns a non-Ok status. If InputShape(index).num_elements() != // new_shape.num_elements(), returns an error status. 
- Status ConstantInputReshaped(int index, gtl::ArraySlice new_shape, + Status ConstantInputReshaped(int index, absl::Span new_dims, xla::Literal* constant_literal); // Converts a constant scalar int32 or int64 tensor into an int64. Status ConstantInputAsIntScalar(int index, int64* out); + Status ConstantInputAsIntScalar(StringPiece name, int64* out); // Converts a constant scalar float32 or float64 tensor into a float64. Status ConstantInputAsFloatScalar(int index, double* out); // Converts a constant 1D int32 or int64 tensor into a vector of int64s. Status ConstantInputAsIntVector(int index, std::vector* out); + Status ConstantInputAsIntVector(StringPiece name, std::vector* out); + + // Reshapes and converts a constant int32 or int64 tensor into a vector of + // int64s. + Status ConstantInputReshapedToIntVector(int index, std::vector* out); // Converts a constant int32 or int64 Tensor into an xla int64 Literal. Status ConstantInputAsInt64Literal(int index, xla::Literal* out); + Status ConstantInputAsInt64Literal(StringPiece name, xla::Literal* out); // Converts a constant 1D int32 or int64 tensor into a TensorShape. 
Status ConstantInputAsShape(int index, TensorShape* shape); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 46785bc1f0a1279bfd67a55844fe238d9797382b..dae2d956ca61a18f7da61fcd0a569a55a6286663 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -105,7 +105,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; /* static */ void XlaOpRegistry::RegisterBackend( const string& compilation_device_name, - gtl::ArraySlice supported_types, BackendOpFilter op_filter) { + absl::Span supported_types, BackendOpFilter op_filter) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); auto result = registry.backends_.emplace(compilation_device_name, Backend()); @@ -325,6 +325,17 @@ std::vector XlaOpRegistry::DeviceKernels( return kernels; } +/*static*/ std::vector XlaOpRegistry::GetAllRegisteredOps() { + std::vector ops; + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + for (const auto& pair : registry.ops_) { + ops.push_back(pair.first); + } + std::sort(ops.begin(), ops.end()); + return ops; +} + /* static */ const std::unordered_set* XlaOpRegistry::CompileTimeConstantInputs(const string& op) { XlaOpRegistry& registry = Instance(); @@ -362,7 +373,7 @@ XlaOpRegistry& XlaOpRegistry::Instance() { XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) { registration_.reset(new XlaOpRegistry::OpRegistration); - registration_->name = std::string(name); + registration_->name = string(name); } XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { @@ -371,17 +382,17 @@ XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( - gtl::ArraySlice devices) { + absl::Span devices) { registration_->has_device_whitelist = true; for (StringPiece device : devices) { - 
registration_->device_whitelist.insert(std::string(device)); + registration_->device_whitelist.emplace(device); } return *this; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) { registration_->has_device_whitelist = true; - registration_->device_whitelist.insert(std::string(device)); + registration_->device_whitelist.emplace(device); return *this; } @@ -398,15 +409,15 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( StringPiece attr_name, DataType allowed) { std::set& types = - registration_->type_constraints[std::string(attr_name)]; + registration_->type_constraints[string(attr_name)]; types.insert(allowed); return *this; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( - StringPiece attr_name, gtl::ArraySlice allowed) { + StringPiece attr_name, absl::Span allowed) { std::set& types = - registration_->type_constraints[std::string(attr_name)]; + registration_->type_constraints[string(attr_name)]; for (DataType t : allowed) { types.insert(t); } @@ -415,7 +426,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( StringPiece input_name) { - registration_->compile_time_constant_inputs.insert(std::string(input_name)); + registration_->compile_time_constant_inputs.emplace(input_name); return *this; } @@ -441,10 +452,10 @@ XlaOpRegistrar::XlaOpRegistrar( } XlaBackendRegistrar::XlaBackendRegistrar( - StringPiece name, gtl::ArraySlice types, + StringPiece name, absl::Span types, XlaOpRegistry::BackendOpFilter op_filter) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); - registry.RegisterBackend(std::string(name), types, op_filter); + registry.RegisterBackend(string(name), types, op_filter); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 
fc14834ca6441ea785eacc57e1f502086f36657e..c640842dc0d4fb3aff64d8388b4ffd3fdcee9faf 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -94,7 +94,7 @@ class XlaOpRegistry { // the device; it may optionally modify the KernelDef. typedef bool (*BackendOpFilter)(KernelDef* kdef); static void RegisterBackend(const string& compilation_device_name, - gtl::ArraySlice supported_types, + absl::Span supported_types, BackendOpFilter op_filter); // Returns the names of the registered backends. @@ -128,6 +128,9 @@ class XlaOpRegistry { const string& compilation_device_name, bool include_compilation_only_kernels); + // Returns all operations for which there are XLA kernels on any device. + static std::vector GetAllRegisteredOps(); + // Returns the set of compile-time constant inputs to 'op'. Returns nullptr // if the op is not registered. static const std::unordered_set* CompileTimeConstantInputs( @@ -233,7 +236,7 @@ class XlaOpRegistrationBuilder { // Specifies a whitelist of devices on which the operator may run. XlaOpRegistrationBuilder& Device(StringPiece devices); - XlaOpRegistrationBuilder& Device(gtl::ArraySlice devices); + XlaOpRegistrationBuilder& Device(absl::Span devices); // Specifies a type constraint for a type variable attribute. Each constraint // specifies the set of types that the type variable may assume. @@ -241,7 +244,7 @@ class XlaOpRegistrationBuilder { DataType allowed); XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name, - gtl::ArraySlice allowed); + absl::Span allowed); // Specifies that a dummy copy of this operator should not be registered on // XLA_* devices, but may be used during compilation. 
@@ -285,7 +288,7 @@ class XlaOpRegistrar { class XlaBackendRegistrar { public: - XlaBackendRegistrar(StringPiece name, gtl::ArraySlice types, + XlaBackendRegistrar(StringPiece name, absl::Span types, XlaOpRegistry::BackendOpFilter op_filter = nullptr); }; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index fdf13bb18c2567d2994612d15119ae87cbfa9137..76e36f3c46b22742b6cf0c86e89d17899338a60f 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -113,6 +113,7 @@ cc_library( ":statusor", ":types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -161,7 +162,6 @@ cc_library( "iterator_util.h", "map_util.h", "overflow_util.h", - "ptr_util.h", "util.h", ], visibility = ["//visibility:public"], @@ -172,7 +172,11 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:ptr_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -210,6 +214,7 @@ tf_cc_test( ":test", ":util", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -236,10 +241,13 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -256,6 +264,7 @@ tf_cc_test( ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -297,6 +306,10 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + 
"@com_google_absl//absl/types:span", ], ) @@ -315,6 +328,8 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -335,6 +350,9 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -353,6 +371,8 @@ cc_library( ":literal_util", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -364,6 +384,8 @@ cc_library( deps = [ ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -373,8 +395,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -385,6 +407,8 @@ cc_library( ":status", ":types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -405,8 +429,9 @@ cc_library( deps = [ ":array", ":types", - ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -451,6 +476,8 @@ cc_library( ":array2d", ":types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -462,6 +489,7 @@ tf_cc_test( ":test", "//tensorflow/core:lib", "//tensorflow/core:test_main", + "@com_google_absl//absl/types:span", ], ) @@ -489,6 +517,8 @@ cc_library( ":util", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/base", + "@com_google_absl//absl/memory", ], ) @@ -503,6 +533,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -521,6 +552,8 @@ cc_library( ":xla_data_proto", "//tensorflow/core:lib", 
"//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -551,6 +584,8 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -576,10 +611,12 @@ cc_library( deps = [ ":shape_util", ":status_macros", - ":util", ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -593,6 +630,7 @@ tf_cc_test( ":xla_data_proto", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -619,6 +657,8 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -642,6 +682,8 @@ cc_library( "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -660,6 +702,7 @@ tf_cc_test( "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -671,7 +714,8 @@ cc_library( ":array2d", ":shape_util", ":xla_data_proto", - "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 2d5d078aa77423cc18bab053b80a7576acbd849e..58cc1575858201b4508d7340cb47e59c4f4c5783 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -27,12 +27,12 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -97,12 +97,11 @@ class Array { using value_type = T; // Creates a new array with the specified dimensions. - explicit Array(tensorflow::gtl::ArraySlice sizes) - : Array(sizes, T()) {} + explicit Array(absl::Span sizes) : Array(sizes, T()) {} // Creates a new array with the specified dimensions and specified value for // every cell. - Array(tensorflow::gtl::ArraySlice sizes, T value) + Array(absl::Span sizes, T value) : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) { Fill(value); } @@ -301,7 +300,7 @@ class Array { // Invokes a callback with the (indices, value_ptr) for each cell in the // array. - void Each(std::function, T*)> f) { + void Each(std::function, T*)> f) { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { f(index, &values_[i]); @@ -309,8 +308,7 @@ class Array { } // Invokes a callback with the (indices, value) for each cell in the array. - void Each( - std::function, T)> f) const { + void Each(std::function, T)> f) const { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { f(index, values_[i]); @@ -320,8 +318,7 @@ class Array { // Invokes a callback with the (indices, value_ptr) for each cell in the // array. If a callback returns a non-OK status, returns that else returns // Status::OK(). 
- Status EachStatus( - std::function, T*)> f) { + Status EachStatus(std::function, T*)> f) { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { Status s = f(index, &values_[i]); @@ -335,8 +332,7 @@ class Array { // Invokes a callback with the (indices, value) for each cell in the array. // If a callback returns a non-OK status, returns that else returns // Status::OK(). - Status EachStatus( - std::function, T)> f) const { + Status EachStatus(std::function, T)> f) const { std::vector index(sizes_.size()); for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { Status s = f(index, values_[i]); @@ -377,13 +373,13 @@ class Array { // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. - const T& operator()(tensorflow::gtl::ArraySlice indexes) const { + const T& operator()(absl::Span indexes) const { return values_[calculate_index(indexes)]; } // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. - T& operator()(tensorflow::gtl::ArraySlice indexes) { + T& operator()(absl::Span indexes) { return values_[calculate_index(indexes)]; } @@ -438,8 +434,8 @@ class Array { bool operator!=(const Array& other) const { return !(*this == other); } // Performs the equivalent of a slice operation on this array. - Array Slice(tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits) const { + Array Slice(absl::Span starts, + absl::Span limits) const { CHECK_EQ(starts.size(), num_dimensions()); CHECK_EQ(limits.size(), num_dimensions()); @@ -464,7 +460,7 @@ class Array { // Performs the equivalent of a DynamicUpdateSlice in-place on this array. 
void UpdateSlice(const Array& from, - tensorflow::gtl::ArraySlice start_indices) { + absl::Span start_indices) { CHECK_EQ(from.num_dimensions(), num_dimensions()); std::vector limit_indices; std::transform(start_indices.begin(), start_indices.end(), @@ -484,7 +480,7 @@ class Array { // Performs an in-place reshape, modifying the dimensions but not the // underlying data. - void Reshape(tensorflow::gtl::ArraySlice new_dimensions) { + void Reshape(absl::Span new_dimensions) { int64 old_num_elements = num_elements(); sizes_ = std::vector(new_dimensions.begin(), new_dimensions.end()); CHECK_EQ(num_elements(), old_num_elements); @@ -507,9 +503,7 @@ class Array { } } - pieces.push_back( - tensorflow::strings::AlphaNum(values_[calculate_index(index)]) - .data()); + pieces.push_back(absl::StrCat(values_[calculate_index(index)])); // Emit comma if it isn't the last element if (index.back() != sizes_.back() - 1) { @@ -527,7 +521,7 @@ class Array { } } } while (next_index(&index)); - return tensorflow::str_util::Join(pieces, ""); + return absl::StrJoin(pieces, ""); } private: diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index a17e81f44832f272fd93dce9f854042b4a84fde4..782c966b4c57672d137569a318fb20ace14d493b 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -24,12 +24,11 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -101,7 +100,7 @@ class Array2D : public Array { template std::unique_ptr> MakeLinspaceArray2D(double from, double to, int64 n1, int64 n2) { - auto array = MakeUnique>(n1, n2); + auto array = absl::make_unique>(n1, n2); int64 count = n1 * n2; NativeT step = static_cast((count > 1) ? (to - from) / (count - 1) : 0); diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index a75fffc605aa0df3e1e2eeb6d3129718cbbba0e4..e23d317baf9aca7b3705a93d6be952fb9a17762b 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -26,13 +26,11 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/array4d_test.cc b/tensorflow/compiler/xla/array4d_test.cc index 927733ea1eab43feff643c35535cc6d9ea59ba5a..918872a7a03a022c72d22dfb8f0da9e9d3820e41 100644 --- a/tensorflow/compiler/xla/array4d_test.cc +++ b/tensorflow/compiler/xla/array4d_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { @@ -27,8 +27,7 @@ namespace { // Given an Array4D and a 4-tuple index, computes the linear index into the // array idx represents. 
template -int64 Array4DLinearIndex(const Array4D& arr, - tensorflow::gtl::ArraySlice idx) { +int64 Array4DLinearIndex(const Array4D& arr, absl::Span idx) { EXPECT_EQ(4, idx.size()); return (idx[3] + idx[2] * arr.n4() + idx[1] * arr.n3() * arr.n4() + idx[0] * arr.n2() * arr.n3() * arr.n4()); @@ -51,9 +50,8 @@ TEST(Array4dTest, FillCtor) { EXPECT_EQ(fullof7.n3(), 4); EXPECT_EQ(fullof7.n4(), 5); - fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { - EXPECT_EQ(*cell, 7); - }); + fullof7.Each( + [](absl::Span idx, int* cell) { EXPECT_EQ(*cell, 7); }); } TEST(Array4dTest, ContainerCtor) { @@ -69,7 +67,7 @@ TEST(Array4dTest, ContainerCtor) { EXPECT_EQ(arr.n3(), 4); EXPECT_EQ(arr.n4(), 5); - arr.Each([&arr](tensorflow::gtl::ArraySlice idx, int* cell) { + arr.Each([&arr](absl::Span idx, int* cell) { EXPECT_EQ(*cell, Array4DLinearIndex(arr, idx)); }); } @@ -129,21 +127,19 @@ TEST(Array3dTest, InitializerListCtorHalf) { TEST(Array4dTest, Fill) { Array4D fullof7(2, 3, 4, 5, 7); - fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { - EXPECT_EQ(*cell, 7); - }); + fullof7.Each( + [](absl::Span idx, int* cell) { EXPECT_EQ(*cell, 7); }); fullof7.Fill(11); - fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { - EXPECT_EQ(*cell, 11); - }); + fullof7.Each( + [](absl::Span idx, int* cell) { EXPECT_EQ(*cell, 11); }); } TEST(Array4dTest, FillWithMultiples) { Array4D arr(2, 3, 4, 5); arr.FillWithMultiples(2.0f); - arr.Each([&arr](tensorflow::gtl::ArraySlice idx, float* cell) { + arr.Each([&arr](absl::Span idx, float* cell) { EXPECT_EQ(*cell, 2.0f * Array4DLinearIndex(arr, idx)); }); } diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc index e8356c9832d34135f5ffb1a5c7a9d6db6db3a051..2d0ac98bd4ee27004295c4189cb190bb2c9739c9 100644 --- a/tensorflow/compiler/xla/array_test.cc +++ b/tensorflow/compiler/xla/array_test.cc @@ -163,7 +163,7 @@ TEST(ArrayTest, Each) { arr.FillWithMultiples(1); int64 each_count = 0, each_sum = 0; 
- arr.Each([&](tensorflow::gtl::ArraySlice idx, int cell) { + arr.Each([&](absl::Span idx, int cell) { int64 lin_idx = idx[0] * 12 + idx[1] * 4 + idx[2]; EXPECT_EQ(lin_idx, cell); each_count++; diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index ad3fcee05b80181369bfdf3cdcdb5452ec9e7e89..f825f67b447514a416f3a49ac8aad9dcf505f5a7 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -45,6 +45,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -71,12 +72,14 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -90,6 +93,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", ], ) @@ -104,7 +110,6 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", @@ -115,8 +120,9 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", "//tensorflow/compiler/xla/service:stream_pool", - "//tensorflow/core:lib", 
"//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", "@llvm//:support", ], ) @@ -130,11 +136,11 @@ cc_library( ":xla_computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:support", ], ) @@ -159,6 +165,7 @@ cc_library( "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -186,6 +193,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", + "@com_google_absl//absl/memory", ], ) @@ -211,6 +219,10 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index d0ce5e8a6afa262d4cffdfe8431aab570ffd28df..8818f813127230d3b39d4b48d874b7cfb24b8abc 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -18,15 +18,15 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -89,7 +89,7 @@ StatusOr> Client::TransferToServer( "TransferToServer request"); } - return MakeUnique(stub_, response.data()); + return absl::make_unique(stub_, response.data()); } Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, @@ -163,8 +163,7 @@ Status Client::ResetDevice() { } StatusOr> Client::ExecuteAndTransfer( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { TF_ASSIGN_OR_RETURN( @@ -212,8 +211,7 @@ StatusOr Client::LoadSnapshot(const HloSnapshot& module) { } StatusOr> Client::Execute( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { ExecuteGraphRequest request; @@ -248,11 +246,11 @@ StatusOr> Client::Execute( } } - return MakeUnique(stub_, response.output()); + return absl::make_unique(stub_, response.output()); } StatusOr>> Client::ExecuteParallel( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { ExecuteGraphParallelRequest request; for (const XlaComputationInstance& computation : 
computations) { @@ -278,7 +276,7 @@ StatusOr>> Client::ExecuteParallel( std::vector> outputs; for (size_t i = 0; i < computations.size(); ++i) { outputs.push_back( - MakeUnique(stub_, response.responses(i).output())); + absl::make_unique(stub_, response.responses(i).output())); if (computations[i].execution_profile != nullptr) { *computations[i].execution_profile = response.responses(i).profile(); } @@ -340,7 +338,7 @@ StatusOr>> Client::DeconstructTuple( std::vector> handles; for (auto& handle : response.element_handles()) { - handles.push_back(MakeUnique(stub_, handle)); + handles.push_back(absl::make_unique(stub_, handle)); } return std::move(handles); } @@ -369,7 +367,7 @@ StatusOr Client::GetComputationStats( StatusOr> Client::GetComputationShape( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); - return MakeUnique(result); + return absl::make_unique(result); } StatusOr Client::GetShape(const GlobalData& data) { @@ -400,7 +398,7 @@ StatusOr Client::ExecutionStatsAsString( int64 nanoseconds = profile.compute_time_ns(); int64 cycle_count = profile.compute_cycle_count(); double gflops = total_flops / nanoseconds; - return tensorflow::strings::StrCat( + return absl::StrCat( "[Execution Statistics] flop count: ", computation_stats.flop_count(), ", transcendental count: ", computation_stats.transcendental_count(), ", compute execution time: ", nanoseconds, " nsec", diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index be50cebfcc0e3c19002635dbd280b14048aa0c93..7960b078686e611a6439af495d266f9084992d29 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -53,7 +53,7 @@ class Client { // will be filled with profile data from the execution. StatusOr> Execute( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); @@ -82,7 +82,7 @@ class Client { // from each computation. // StatusOr>> ExecuteParallel( - tensorflow::gtl::ArraySlice computations); + absl::Span computations); // Requests device_count device handles available on the target. The returned // device handles are used to specify the devices to execute the computations @@ -134,7 +134,7 @@ class Client { // Execute() and Transfer(). StatusOr> ExecuteAndTransfer( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions* execution_options = nullptr, ExecutionProfile* execution_profile = nullptr); diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index 803a9e40094391ba47ed27713f4538caf875c4f6..27b7fa7b29206affa9f9c2e4becd9e4ea66484ab 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client_library.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -94,10 +95,10 @@ ClientLibrary::~ClientLibrary() = default; service_options.set_intra_op_parallelism_threads( options.intra_op_parallelism_threads()); - auto instance = MakeUnique(); + auto instance = absl::make_unique(); TF_ASSIGN_OR_RETURN(instance->service, LocalService::NewService(service_options)); - instance->client = MakeUnique(instance->service.get()); + instance->client = absl::make_unique(instance->service.get()); LocalClient* cl = instance->client.get(); client_library.local_instances_.insert( @@ -134,10 +135,11 @@ ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) { return it->second->client.get(); } - auto instance = MakeUnique(); + auto instance = absl::make_unique(); TF_ASSIGN_OR_RETURN(instance->service, CompileOnlyService::NewService(platform)); - instance->client = MakeUnique(instance->service.get()); + instance->client = + absl::make_unique(instance->service.get()); CompileOnlyClient* cl = instance->client.get(); client_library.compile_only_instances_.insert( diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 5c9abad4c3126be5e45e96c770c0679fe8606788..a6c58cb17571b63cd0f45d0d95376a02bc4a72e2 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -15,15 +15,15 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/compile_only_client.h" +#include "absl/memory/memory.h" #include "llvm/ADT/Triple.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" namespace xla { StatusOr>> CompileOnlyClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata) { std::vector service_instances; @@ -41,7 +41,7 @@ CompileOnlyClient::CompileAheadOfTime( metadata); } -int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) { +int64 CompileOnlyClient::PointerSizeForTriple(absl::string_view triple) { llvm::Triple llvm_triple( llvm::Triple::normalize(llvm::StringRef(triple.data(), triple.size()))); if (llvm_triple.isArch64Bit()) { diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index a551edeab0943ec5213c5cb035644c02c3cf54d7..9e3ed23734941d98d622c38028cd44d48d3e620a 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -52,12 +52,12 @@ class CompileOnlyClient : public Client { // code. |metadata|, if provided, is populated during compilation. StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata = nullptr); // Returns the size of a pointer in bytes for a given triple. 
- static int64 PointerSizeForTriple(tensorflow::StringPiece triple); + static int64 PointerSizeForTriple(absl::string_view triple); private: CompileOnlyService* compiler_service_; diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 7dee41f6a05025ec196b78e54015e8e71777031f..0f1745366b7c33e573aff2e66d85431b01488c49 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -59,10 +59,10 @@ string ExecutableBuildOptions::ToString() const { if (generate_hlo_graph_.has_value()) { generate_hlo_graph = generate_hlo_graph_.value(); } - return tensorflow::strings::Printf( + return absl::StrFormat( "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " "generate_hlo_graph=%s}", - device_ordinal_, result_layout.c_str(), generate_hlo_graph.c_str()); + device_ordinal_, result_layout, generate_hlo_graph); } ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph( @@ -71,41 +71,41 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph( return *this; } -const tensorflow::gtl::optional& -ExecutableBuildOptions::generate_hlo_graph() const { +const absl::optional& ExecutableBuildOptions::generate_hlo_graph() + const { return generate_hlo_graph_; } ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to( - tensorflow::StringPiece dirpath) { - dump_optimized_hlo_proto_to_ = dirpath.ToString(); + absl::string_view dirpath) { + dump_optimized_hlo_proto_to_ = string(dirpath); return *this; } -const tensorflow::gtl::optional& +const absl::optional& 
ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { return dump_optimized_hlo_proto_to_; } ExecutableBuildOptions& ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( - tensorflow::StringPiece dirpath) { - dump_unoptimized_hlo_proto_to_ = dirpath.ToString(); + absl::string_view dirpath) { + dump_unoptimized_hlo_proto_to_ = string(dirpath); return *this; } -const tensorflow::gtl::optional& +const absl::optional& ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { return dump_unoptimized_hlo_proto_to_; } ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( - tensorflow::StringPiece dirpath) { - dump_per_pass_hlo_proto_to_ = dirpath.ToString(); + absl::string_view dirpath) { + dump_per_pass_hlo_proto_to_ = string(dirpath); return *this; } -const tensorflow::gtl::optional& +const absl::optional& ExecutableBuildOptions::dump_per_pass_hlo_proto_to() const { return dump_per_pass_hlo_proto_to_; } @@ -115,7 +115,7 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) { return *this; } -tensorflow::gtl::optional ExecutableBuildOptions::hlo_profile() const { +absl::optional ExecutableBuildOptions::hlo_profile() const { return hlo_profile_; } diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 9dc9be4423564fb967b247c2d1df31099cb80237..93334db88bc24f2ffbf3c7a57ee45ef238286739 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -16,11 +16,11 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -57,37 +57,36 @@ class ExecutableBuildOptions { // If set, specifies a regexp of HLO graphs to dump (as in DebugOptions). ExecutableBuildOptions& set_generate_hlo_graph(string regex); - const tensorflow::gtl::optional& generate_hlo_graph() const; + const absl::optional& generate_hlo_graph() const; // If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO // protobuf to (as in DebugOptions). ExecutableBuildOptions& set_dump_optimized_hlo_proto_to( - tensorflow::StringPiece dirpath); - const tensorflow::gtl::optional& dump_optimized_hlo_proto_to() const; + absl::string_view dirpath); + const absl::optional& dump_optimized_hlo_proto_to() const; // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO // protobuf to (as in DebugOptions). ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( - tensorflow::StringPiece dirpath); - const tensorflow::gtl::optional& dump_unoptimized_hlo_proto_to() - const; + absl::string_view dirpath); + const absl::optional& dump_unoptimized_hlo_proto_to() const; // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs // to (as in DebugOptions). 
ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( - tensorflow::StringPiece dirpath); - const tensorflow::gtl::optional& dump_per_pass_hlo_proto_to() const; + absl::string_view dirpath); + const absl::optional& dump_per_pass_hlo_proto_to() const; // If true, specifies that we should record an HLO profile during execution // and log it after execution (as in DebugOptions). If nullopt the default is // used. ExecutableBuildOptions& set_hlo_profile(bool enabled); - tensorflow::gtl::optional hlo_profile() const; + absl::optional hlo_profile() const; - void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) { + void add_disabled_hlo_pass(absl::string_view pass_name) { disabled_hlo_passes_.push_back(std::string(pass_name)); } - const tensorflow::gtl::ArraySlice disabled_hlo_passes() const { + const absl::Span disabled_hlo_passes() const { return disabled_hlo_passes_; } @@ -96,14 +95,14 @@ class ExecutableBuildOptions { string ToString() const; private: - tensorflow::gtl::optional hlo_profile_; + absl::optional hlo_profile_; int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; - tensorflow::gtl::optional generate_hlo_graph_; - tensorflow::gtl::optional dump_optimized_hlo_proto_to_; - tensorflow::gtl::optional dump_unoptimized_hlo_proto_to_; - tensorflow::gtl::optional dump_per_pass_hlo_proto_to_; + absl::optional generate_hlo_graph_; + absl::optional dump_optimized_hlo_proto_to_; + absl::optional dump_unoptimized_hlo_proto_to_; + absl::optional dump_per_pass_hlo_proto_to_; DeviceMemoryAllocator* device_allocator_ = nullptr; std::vector disabled_hlo_passes_; }; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index a2f32ab97eab10294a607f35fc79ded1cc2c5792..a18c94c4e695a6cdcb9dcc60b64b617cecd276d8 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -31,7 +31,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -64,6 +64,17 @@ xla_test( ], ) +cc_library( + name = "conv_grad_size_util", + srcs = ["conv_grad_size_util.cc"], + hdrs = ["conv_grad_size_util.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/core:lib", + ], +) + cc_library( name = "math", srcs = ["math.cc"], @@ -102,7 +113,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -128,9 +139,9 @@ cc_library( deps = [ ":arithmetic", ":constants", - "//tensorflow/compiler/tf2xla/lib:util", + ":conv_grad_size_util", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -142,6 +153,7 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -209,5 +221,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 9225b1acd69c214d6f08a45372a8082ed789c18c..e86c10f030f3990d67e5a6638100640f73c82307 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { @@ -39,7 +39,7 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, b = builder->CreateSubBuilder(name); } else { b = builder->CreateSubBuilder( - tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type))); + absl::StrCat(name, "_", PrimitiveType_Name(type))); } const Shape scalar = ShapeUtil::MakeShape(type, {}); diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc index 031d62e4ffef188082303a28866bbc72a154e9b1..1ada7b4a964ccf7ca400b937abbe425bef083468 100644 --- a/tensorflow/compiler/xla/client/lib/constants.cc +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -56,7 +56,7 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) { std::numeric_limits::epsilon()); default: return builder->ReportError(InvalidArgument( - "Invalid type for Epsilon (%s).", PrimitiveType_Name(type).c_str())); + "Invalid type for Epsilon (%s).", PrimitiveType_Name(type))); } } diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index 0c8a9b8cc02ba0c1ebdf6a060d4b99262dceb178..81624614c1e3599dfe116eb61d9e2edcd5230684 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -37,13 +37,13 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { primitive_util::IsComplexType(type))) { return builder->ReportError(InvalidArgument( "Invalid cast from floating point type to %s in 
ConstantR0WithType.", - PrimitiveType_Name(type).c_str())); + PrimitiveType_Name(type))); } if (std::is_same::value && !primitive_util::IsComplexType(type)) { return builder->ReportError(InvalidArgument( "Invalid cast from complex type to %s in ConstantR0WithType.", - PrimitiveType_Name(type).c_str())); + PrimitiveType_Name(type))); } switch (type) { case F16: @@ -71,7 +71,7 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { default: return builder->ReportError( InvalidArgument("Invalid type for ConstantR0WithType (%s).", - PrimitiveType_Name(type).c_str())); + PrimitiveType_Name(type))); } } diff --git a/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4c50a5491803bc62d2de758177f8f5d050f441d --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.cc @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +StatusOr GetWindowedOutputSize( + int64 input_size, int64 filter_size, int64 dilation_rate, int64 stride, + Padding padding_type) { + if (stride <= 0) { + return tensorflow::errors::InvalidArgument("Stride must be > 0, but got ", + stride); + } + if (dilation_rate < 1) { + return tensorflow::errors::InvalidArgument( + "Dilation rate must be >= 1, but got ", dilation_rate); + } + + int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1; + SpatialDimensionOutputSizeAndPadding dim; + switch (padding_type) { + case Padding::kValid: + dim.output_size = (input_size - effective_filter_size + stride) / stride; + dim.pad_before = dim.pad_after = 0; + break; + case Padding::kSame: + dim.output_size = (input_size + stride - 1) / stride; + const int64 padding_needed = + std::max(int64{0}, (dim.output_size - 1) * stride + + effective_filter_size - input_size); + // For odd values of total padding, add more padding on the "after" side + // of the given dimension. 
+ dim.pad_before = padding_needed / 2; + dim.pad_after = padding_needed - dim.pad_before; + break; + } + if (dim.output_size < 0) { + return tensorflow::errors::InvalidArgument( + "Computed output size would be negative: ", dim.output_size, + " [input_size: ", input_size, + ", effective_filter_size: ", effective_filter_size, + ", stride: ", stride, "]"); + } + return dim; +} + +} // namespace + +StatusOr +ConvGradExtractAndVerifyDimension(int64 input_size, int64 filter_size, + int64 output_size, int64 dilation, + int64 stride, Padding padding) { + TF_ASSIGN_OR_RETURN(SpatialDimensionOutputSizeAndPadding output_dim, + GetWindowedOutputSize(input_size, filter_size, dilation, + stride, padding)); + if (output_size != output_dim.output_size) { + return tensorflow::errors::InvalidArgument( + "Size of out_backprop doesn't match computed: ", "actual = ", + output_size, ", computed = ", output_dim.output_size, + " input: ", input_size, " filter: ", filter_size, + " output: ", output_size, " stride: ", stride, " dilation: ", dilation); + } + + SpatialDimensionOutputSizeAndPadding dim; + int64 effective_filter_size = (filter_size - 1) * dilation + 1; + dim.output_size = (output_dim.output_size - 1) * stride + 1; + const auto padded_out_size = input_size + effective_filter_size - 1; + dim.pad_before = effective_filter_size - 1 - output_dim.pad_before; + dim.pad_after = padded_out_size - dim.output_size - dim.pad_before; + VLOG(2) << "expanded_out = " << dim.output_size + << ", effective_filter_size = " << effective_filter_size + << ", padded_out = " << padded_out_size + << ", pad_before = " << dim.pad_before + << ", pad_after = " << dim.pad_after << ", dilation = " << dilation + << ", strides = " << stride; + return dim; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h new file mode 100644 index 
0000000000000000000000000000000000000000..0ad01728e6e828240b9ac4b948777e5d970d09e0 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/conv_grad_size_util.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ + +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Information about a single spatial dimension for a convolution gradients and +// windowed operations. +struct SpatialDimensionOutputSizeAndPadding { + // Effective size of the operation output (potentially expanded). + int64 output_size; + // Number of padding elements to be added before/after this dimension of + // the input when computing the input gradient. + int64 pad_before; + int64 pad_after; +}; + +// Verifies that the dimensions all match, and computes the size and padding of +// a spatial dimension for convolution gradient operations. 
+StatusOr +ConvGradExtractAndVerifyDimension(int64 input_size, int64 filter_size, + int64 output_size, int64 dilation, + int64 stride, Padding padding); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 0221de7672c7b7c02b1f8b9c7ff4f92151e567c6..d3d7edb42a38595bbf9fdb36e0dd946ae5df51f9 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -69,8 +69,7 @@ std::array kErfUCoefficient = { // Evaluate the polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, - tensorflow::gtl::ArraySlice coefficients) { +XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients) { XlaOp poly = ScalarLike(x, 0.0); for (float c : coefficients) { poly = poly * x + ScalarLike(x, c); @@ -207,7 +206,11 @@ XlaOp Lgamma(XlaOp input) { XlaOp log_y = log_sqrt_two_pi + (z + one_half) * log_t - t + Log(x); - XlaOp reflection = log_pi - Log(Sin(pi * input)) - log_y; + // If z = a + 0j, the analytic continuation of log reduces to taking the + // absolute value of the real part. + // Re(log(z)) = Re(log|z| + arg(z)j) + // = log|a| + XlaOp reflection = log_pi - Log(Abs(Sin(pi * input))) - log_y; XlaOp result = Select(need_to_reflect, reflection, log_y); return result; } diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 13db2325569cf2e25e3ff1200adf4b2544dc2f73..a6cafd42077367bf23ffa1f45eab31c01dc31b16 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -34,8 +34,7 @@ XlaOp Reciprocal(XlaOp operand); // Evaluates a polynomial given coefficients and `x`. // N.B. Coefficients should be supplied in decreasing order. 
-XlaOp EvaluatePolynomial(XlaOp x, - tensorflow::gtl::ArraySlice coefficients); +XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients); // Computes an approximation of the error function complement (1 - erf(x)). XlaOp Erfc(XlaOp x); diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc index 1c91237ae1574f92cda78c9bddc6f4ac1d68f47c..377654220b5df4487e9e194361473d54ff46a54e 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.cc +++ b/tensorflow/compiler/xla/client/lib/numeric.cc @@ -16,61 +16,13 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { -namespace { - -template -XlaOp MakeIota(XlaBuilder* builder, int64 size) { - std::vector values(size); - for (int64 i = 0; i < size; ++i) { - values[i] = static_cast(i); - } - return ConstantR1(builder, values); -} - -} // namespace - -XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { - switch (type) { - case S8: - return MakeIota(builder, size); - case S16: - return MakeIota(builder, size); - case S32: - return MakeIota(builder, size); - case S64: - return MakeIota(builder, size); - case U8: - return MakeIota(builder, size); - case U16: - return MakeIota(builder, size); - case U32: - return MakeIota(builder, size); - case U64: - return MakeIota(builder, size); - case BF16: - return MakeIota(builder, size); - case F16: - return MakeIota(builder, size); - case F32: - return MakeIota(builder, size); - case F64: - return MakeIota(builder, size); - case C64: - return MakeIota(builder, size); - default: - return builder->ReportError( - InvalidArgument("Unimplemented type for Iota: %s.", - PrimitiveType_Name(type).c_str())); - } -} - XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType 
type, int64 m, int64 n) { auto a = Iota(builder, type, m); @@ -87,8 +39,8 @@ XlaOp GetMatrixDiagonal(XlaOp x) { TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); - tensorflow::gtl::ArraySlice major_dims( - AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); auto a = Iota(builder, U32, n); auto b = Iota(builder, U32, m); auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0}); @@ -114,8 +66,8 @@ XlaOp Triangle(XlaOp x, bool lower) { TF_RET_CHECK(n_dims >= 2); const int64 m = shape.dimensions(n_dims - 2); const int64 n = shape.dimensions(n_dims - 1); - tensorflow::gtl::ArraySlice major_dims( - AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - 2); + absl::Span major_dims = + AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2); auto a = Iota(builder, U32, n); auto b = Iota(builder, U32, m); xla::XlaOp indicator; diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc index 8a96ec68d2dca8485215258b1f6731b934e6f2a8..7d6aedd49462bd4f075f90d0b0f85c40f1191aa1 100644 --- a/tensorflow/compiler/xla/client/lib/numeric_test.cc +++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc @@ -30,16 +30,6 @@ class NumericTest : public ClientLibraryTestBase { void TestMatrixDiagonal(); }; -// TODO(b/64798317): Delete this test case once xla::IotaGen is converted to -// xla::Iota. This test is already implemented for xla::IotaGen in -// xla/tests/iota_test.cc. 
-XLA_TEST_F(NumericTest, Iota) { - XlaBuilder builder(TestName()); - Iota(&builder, S32, 10); - - ComputeAndCompareR1(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {}); -} - XLA_TEST_F(NumericTest, Triangle) { XlaBuilder builder(TestName()); Array3D input(2, 3, 4); diff --git a/tensorflow/compiler/xla/client/lib/pooling.cc b/tensorflow/compiler/xla/client/lib/pooling.cc index 7199269a6c889f3589c1148687faf0bb2aaae90a..1979c867a4c3be438f8b997c566799fe84b43053 100644 --- a/tensorflow/compiler/xla/client/lib/pooling.cc +++ b/tensorflow/compiler/xla/client/lib/pooling.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/pooling.h" -#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h" namespace xla { @@ -26,11 +26,9 @@ namespace { // element of an image by the count of elements that contributed to that // element during pooling. XlaOp AvgPoolDivideByCountWithGeneralPadding( - XlaOp sums, PrimitiveType dtype, - tensorflow::gtl::ArraySlice input_shape, - tensorflow::gtl::ArraySlice> spatial_padding, - tensorflow::gtl::ArraySlice ksize, - tensorflow::gtl::ArraySlice stride, + XlaOp sums, PrimitiveType dtype, absl::Span input_shape, + absl::Span> spatial_padding, + absl::Span ksize, absl::Span stride, const TensorFormat& data_format) { // The padding shouldn't be included in the counts. We use another // ReduceWindow to find the right counts. @@ -73,8 +71,8 @@ XlaOp AvgPoolDivideByCountWithGeneralPadding( // Sums all elements in the window specified by 'kernel_size' and 'stride'. 
XlaOp ComputeSums(XlaOp operand, XlaOp init_value, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, + absl::Span kernel_size, + absl::Span stride, const TensorFormat& data_format) { XlaBuilder* b = operand.builder(); return b->ReportErrorOrReturn([&]() -> StatusOr { @@ -89,11 +87,9 @@ XlaOp ComputeSums(XlaOp operand, XlaOp init_value, // Creates a padding configuration out of spatial padding values. PaddingConfig MakeSpatialPaddingConfig( - tensorflow::gtl::ArraySlice> spatial_padding, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, + absl::Span> spatial_padding, + int num_spatial_dims, absl::Span stride, const TensorFormat& data_format) { - const int num_spatial_dims = kernel_size.size() - 2; PaddingConfig padding_config; for (int i = 0; i < 2 + num_spatial_dims; ++i) { padding_config.add_dimensions(); @@ -109,10 +105,33 @@ PaddingConfig MakeSpatialPaddingConfig( return padding_config; } +XlaOp AvgPoolDivideByCount(XlaOp pooled, absl::Span input_size, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + PrimitiveType dtype, const TensorFormat& data_format, + bool counts_include_padding) { + if (counts_include_padding) { + // If counts include padding, all windows have the same number of elements + // contributing to each average. Divide by the window size everywhere to get + // the average. 
+ int64 window_size = + std::accumulate(window_dimensions.begin(), window_dimensions.end(), 1, + [](int64 a, int64 b) { return a * b; }); + auto divisor = ConstantR0WithType(pooled.builder(), dtype, window_size); + + return pooled / divisor; + } else { + return AvgPoolDivideByCountWithGeneralPadding(pooled, dtype, input_size, + padding, window_dimensions, + window_strides, data_format); + } +} + } // namespace -XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, +XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, Padding padding, const TensorFormat& data_format) { XlaBuilder* b = operand.builder(); return b->ReportErrorOrReturn([&]() -> StatusOr { @@ -125,9 +144,9 @@ XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, }); } -XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice> padding, +XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, + absl::Span> padding, const TensorFormat& data_format, const bool counts_include_padding) { XlaBuilder* b = operand.builder(); @@ -137,32 +156,22 @@ XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, auto init_value = Zero(b, dtype); std::vector input_size(operand_shape.dimensions().begin(), operand_shape.dimensions().end()); - auto padding_config = - MakeSpatialPaddingConfig(padding, kernel_size, stride, data_format); + const int num_dims = kernel_size.size(); + const int num_spatial_dims = num_dims - 2; + auto padding_config = MakeSpatialPaddingConfig(padding, num_spatial_dims, + stride, data_format); auto padded_operand = Pad(operand, Zero(b, dtype), padding_config); auto pooled = ComputeSums(padded_operand, init_value, kernel_size, stride, data_format); - if (counts_include_padding) { - // If counts include padding, all windows have the same number of elements - // contributing to each 
average. Divide by the window size everywhere to - // get the average. - int64 window_size = - std::accumulate(kernel_size.begin(), kernel_size.end(), 1, - [](int64 x, int64 y) { return x * y; }); - - auto divisor = ConstantR0WithType(b, dtype, window_size); - return pooled / divisor; - } else { - return AvgPoolDivideByCountWithGeneralPadding( - pooled, dtype, input_size, padding, kernel_size, stride, data_format); - } + return AvgPoolDivideByCount(pooled, input_size, kernel_size, stride, + padding, dtype, data_format, + counts_include_padding); }); } std::vector> MakeSpatialPadding( - tensorflow::gtl::ArraySlice input_size, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, + absl::Span input_size, absl::Span kernel_size, + absl::Span stride, Padding padding, const TensorFormat& data_format) { const int num_spatial_dims = kernel_size.size() - 2; std::vector input_spatial_dimensions; @@ -180,4 +189,101 @@ std::vector> MakeSpatialPadding( stride_spatial_dimensions, padding); } +XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, + absl::Span kernel_size, + absl::Span stride, + absl::Span> spatial_padding, + const TensorFormat& data_format, + const bool counts_include_padding) { + XlaBuilder* b = out_backprop.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + const int num_dims = kernel_size.size(); + + if (gradients_size.size() != num_dims) { + return tensorflow::errors::InvalidArgument("gradients must be ", num_dims, + "-dimensional"); + } + + TF_ASSIGN_OR_RETURN(Shape out_backprop_xla_shape, + b->GetShape(out_backprop)); + if (out_backprop_xla_shape.dimensions().size() != num_dims) { + return tensorflow::errors::InvalidArgument("out_backprop must be ", + num_dims, "-dimensional"); + } + + // We can think of average-pooling as: + // * a convolution with a kernel consisting entirely of 1s, where the + // input feature and output feature are equal, and 0s everywhere else. 
+ // * followed by dividing by the counts. + // + // This then gives us an algorithm to build the gradient: + // * divide out_backprop by the counts, followed by + // * Conv2DBackpropInput specialized for that kernel, which simplifies to + // a Pad and a ReduceWindow. + // + // For an explanation of backpropagation for convolution, see the comments + // in third_party/tensorflow/core/kernels/conv_grad_ops.h + + // TF filter shape is [ H, W, ..., inC, outC ] + + // The input gradients are computed by a convolution of the output gradients + // and the filter, with some appropriate padding. See the comment at the top + // of conv_grad_ops.h for details. + PrimitiveType dtype = out_backprop_xla_shape.element_type(); + auto out_backprop_div = AvgPoolDivideByCount( + out_backprop, gradients_size, kernel_size, stride, spatial_padding, + dtype, data_format, counts_include_padding); + + // Pad the gradients in the spatial dimensions. We use the same padding + // as Conv2DBackpropInput. + PaddingConfig padding_config = MakeNoPaddingConfig(num_dims); + std::vector padded_gradients_size(gradients_size.begin(), + gradients_size.end()); + // First, pad the output gradients the same way as the input. The additional + // padding will be removed as a last step before returning the input + // gradients. 
+ const int num_spatial_dims = num_dims - 2; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + padded_gradients_size[dim] += + (spatial_padding[i].first + spatial_padding[i].second); + } + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = data_format.spatial_dimension(i); + TF_ASSIGN_OR_RETURN( + SpatialDimensionOutputSizeAndPadding conv_backprop_spatial_dim, + ConvGradExtractAndVerifyDimension( + /*input_size=*/padded_gradients_size[dim], + /*filter_size=*/kernel_size[dim], + /*output_size=*/out_backprop_xla_shape.dimensions(dim), + /*dilation=*/1, + /*stride=*/stride[dim], /*padding=*/Padding::kValid)); + auto* padding = padding_config.mutable_dimensions(dim); + padding->set_edge_padding_low(conv_backprop_spatial_dim.pad_before); + padding->set_edge_padding_high(conv_backprop_spatial_dim.pad_after); + padding->set_interior_padding(stride[dim] - 1); + } + + auto zero = Zero(b, dtype); + auto padded_gradients = Pad(out_backprop_div, zero, padding_config); + + // in_backprop = padded_gradients ones + std::vector ones(num_dims, 1LL); + auto in_backprop = + ReduceWindow(padded_gradients, Zero(b, dtype), + CreateScalarAddComputation(dtype, b), kernel_size, + /*window_strides=*/ones, Padding::kValid); + // The input padding doesn't contribute to the gradient, remove it. 
+ std::vector> neg_spatial_padding; + neg_spatial_padding.reserve(spatial_padding.size()); + for (const std::pair& spatial_padding_dim : spatial_padding) { + neg_spatial_padding.emplace_back(-spatial_padding_dim.first, + -spatial_padding_dim.second); + } + auto remove_padding_config = MakeSpatialPaddingConfig( + neg_spatial_padding, num_spatial_dims, stride, data_format); + return Pad(in_backprop, zero, remove_padding_config); + }); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/pooling.h b/tensorflow/compiler/xla/client/lib/pooling.h index 1699c585d3b09a306c21cfa797a9023a8463bd1f..5c0054857d072dc7f36e259a29b9b24fd70796ac 100644 --- a/tensorflow/compiler/xla/client/lib/pooling.h +++ b/tensorflow/compiler/xla/client/lib/pooling.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { @@ -25,7 +25,7 @@ namespace xla { class TensorFormat { public: TensorFormat(int batch_dimension, int feature_dimension, - tensorflow::gtl::ArraySlice spatial_dimensions) + absl::Span spatial_dimensions) : batch_dimension_(batch_dimension), feature_dimension_(feature_dimension), spatial_dimensions_(spatial_dimensions.begin(), @@ -45,29 +45,36 @@ class TensorFormat { // The number of the dimension that represents the features. int feature_dimension_; // The dimension numbers for the spatial dimensions. - tensorflow::gtl::InlinedVector spatial_dimensions_; + absl::InlinedVector spatial_dimensions_; }; // Computes the max pool of 'operand'. 
-XlaOp MaxPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, +XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, Padding padding, const TensorFormat& data_format); // Computes the average pool of 'operand'. -XlaOp AvgPool(XlaOp operand, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, - tensorflow::gtl::ArraySlice> padding, +XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, + absl::Span> padding, const TensorFormat& data_format, const bool counts_include_padding); // Returns the list of low and high padding elements in each spatial dimension // for the given 'padding' specification. std::vector> MakeSpatialPadding( - tensorflow::gtl::ArraySlice input_size, - tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, + absl::Span input_size, absl::Span kernel_size, + absl::Span stride, Padding padding, const TensorFormat& data_format); +// Computes the average pool gradient. +XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, + absl::Span kernel_size, + absl::Span stride, + absl::Span> spatial_padding, + const TensorFormat& data_format, + const bool counts_include_padding); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_POOLING_H_ diff --git a/tensorflow/compiler/xla/client/lib/pooling_test.cc b/tensorflow/compiler/xla/client/lib/pooling_test.cc index 4b4553b60db555ad7c2ab6b695236df745e30683..30adb9b1ad7fa03b40ce3802a2172680b60a9ad7 100644 --- a/tensorflow/compiler/xla/client/lib/pooling_test.cc +++ b/tensorflow/compiler/xla/client/lib/pooling_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/client/lib/pooling.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -22,7 +23,7 @@ namespace xla { namespace { TensorFormat MakeNCHWFormat(int num_spatial_dims) { - tensorflow::gtl::InlinedVector spatial_dimensions; + absl::InlinedVector spatial_dimensions; for (int i = 0; i < num_spatial_dims; ++i) { spatial_dimensions.push_back(i + 2); } @@ -31,8 +32,8 @@ TensorFormat MakeNCHWFormat(int num_spatial_dims) { } std::vector> MakeGeneralPadding( - XlaOp input, tensorflow::gtl::ArraySlice kernel_size, - tensorflow::gtl::ArraySlice stride, Padding padding, + XlaOp input, absl::Span kernel_size, + absl::Span stride, Padding padding, const xla::TensorFormat& data_format) { XlaBuilder* b = input.builder(); Shape operand_shape = b->GetShape(input).ValueOrDie(); @@ -45,7 +46,7 @@ std::vector> MakeGeneralPadding( // Add singleton batch and feature dimensions to spatial dimensions, according // to 'data_format' specification. 
std::vector ExpandWithBatchAndFeatureDimensions( - tensorflow::gtl::ArraySlice spatial_dim_sizes, + absl::Span spatial_dim_sizes, const xla::TensorFormat& data_format) { const int num_spatial_dims = spatial_dim_sizes.size(); std::vector tensor_sizes(num_spatial_dims + 2, 1); @@ -181,5 +182,109 @@ XLA_TEST_F(PoolingTest, error_spec_); } +XLA_TEST_F(PoolingTest, AvgPool2DGradNoPadding) { + XlaBuilder builder(TestName()); + for (bool counts_include_padding : {false, true}) { + XlaOp out_backprop = ConstantR4FromArray4D(&builder, {{{{1.}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, + {{0, 0}, {0, 0}}, MakeNCHWFormat(2), + /*counts_include_padding=*/counts_include_padding); + // Without padding, counts_include_padding makes no difference. + ComputeAndCompareR4( + &builder, {{{{0.25, 0.25, 0.}, {0.25, 0.25, 0.}, {0., 0., 0.}}}}, {}, + error_spec_); + } +} + +XLA_TEST_F(PoolingTest, AvgPool2DGradNoPaddingWithStride) { + XlaBuilder builder(TestName()); + for (bool counts_include_padding : {false, true}) { + XlaOp out_backprop = + ConstantR4FromArray4D(&builder, {{{{1., 1.}, {1., 1.}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, + {{0, 0}, {0, 0}}, MakeNCHWFormat(2), + /*counts_include_padding=*/counts_include_padding); + // Without padding, counts_include_padding makes no difference. 
+ ComputeAndCompareR4( + &builder, {{{{0.25, 0.5, 0.25}, {0.5, 1., 0.5}, {0.25, 0.5, 0.25}}}}, + {}, error_spec_); + } +} + +XLA_TEST_F(PoolingTest, AvgPool2DGradWithPadding) { + XlaBuilder builder(TestName()); + + XlaOp out_backprop = + ConstantR4FromArray4D(&builder, {{{{1., 1.}, {1., 1.}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}}, + MakeNCHWFormat(2), + /*counts_include_padding=*/true); + ComputeAndCompareR4( + &builder, + {{{{0.25, 0.25, 0.25}, {0.25, 0.25, 0.25}, {0.25, 0.25, 0.25}}}}, {}, + error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountNotIncludePadding) { + XlaBuilder builder(TestName()); + + XlaOp out_backprop = + ConstantR4FromArray4D(&builder, {{{{1., 1.}, {1., 1.}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}}, + MakeNCHWFormat(2), false); + ComputeAndCompareR4( + &builder, {{{{1., 0.5, 0.5}, {0.5, 0.25, 0.25}, {0.5, 0.25, 0.25}}}}, {}, + error_spec_); +} + +XLA_TEST_F(PoolingTest, AvgPool2DGradWithPaddingCountWithStride) { + XlaBuilder builder(TestName()); + + XlaOp out_backprop = + ConstantR4FromArray4D(&builder, {{{{1., 1., 1., 1.}, + {1., 1., 1., 1.}, + {1., 1., 1., 1.}, + {1., 1., 1., 1.}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}}, + MakeNCHWFormat(2), true); + ComputeAndCompareR4(&builder, + {{{{1., 1., 1.}, {1., 1., 1.}, {1., 1., 
1.}}}}, {}, + error_spec_); +} + +XLA_TEST_F(PoolingTest, + AvgPool2DGradWithPaddingCountWithStrideNotIncludePadding) { + XlaBuilder builder(TestName()); + + XlaOp out_backprop = + ConstantR4FromArray4D(&builder, {{{{1., 1., 1., 1.}, + {1., 1., 1., 1.}, + {1., 1., 1., 1.}, + {1., 1., 1., 1.}}}}); + auto data_format = MakeNCHWFormat(2); + auto kernel_size = ExpandWithBatchAndFeatureDimensions({2, 2}, data_format); + auto stride = ExpandWithBatchAndFeatureDimensions({1, 1}, data_format); + AvgPoolGrad(out_backprop, {1, 1, 3, 3}, kernel_size, stride, {{1, 1}, {1, 1}}, + MakeNCHWFormat(2), false); + ComputeAndCompareR4( + &builder, {{{{2.25, 1.5, 2.25}, {1.5, 1., 1.5}, {2.25, 1.5, 2.25}}}}, {}, + error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 081fec7ad92958aa285e4be41394d7b1876e0815..6861521acc0db1d640666a6793b898a183ab6a17 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -23,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -61,8 +61,7 @@ XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, Client* client) { - XlaBuilder b( - tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); + XlaBuilder b(absl::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); XlaComputation computation = b.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index cffb24e29beda6a8c40dca2fe709be22892dd489..4402ba8762c1538951c326c880fc3b6dd63ef0c6 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/source_map_util.h" @@ -51,7 +51,7 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, } Status LocalExecutable::ValidateExecutionOptions( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend) { const ComputationLayout& computation_layout = executable_->module_config().entry_computation_layout(); @@ -59,7 +59,7 @@ Status LocalExecutable::ValidateExecutionOptions( // Check argument number, shapes, and layouts. 
if (arguments.size() != computation_layout.parameter_count()) { return InvalidArgument( - "invalid number of arguments for computation: expected %d, got %zu", + "invalid number of arguments for computation: expected %d, got %u", computation_layout.parameter_count(), arguments.size()); } for (int i = 0; i < arguments.size(); ++i) { @@ -71,9 +71,9 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) - .c_str(), - ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); + ShapeUtil::HumanString( + computation_layout.parameter_layout(i).shape()), + ShapeUtil::HumanString(arguments[i]->on_host_shape())); } } @@ -88,8 +88,7 @@ Status LocalExecutable::ValidateExecutionOptions( if (stream_platform != backend_->platform()) { return InvalidArgument( "stream is for platform %s, but service targets platform %s", - stream_platform->Name().c_str(), - backend_->platform()->Name().c_str()); + stream_platform->Name(), backend_->platform()->Name()); } // Cannot specify device_ordinal with a stream. 
The stream determines these @@ -120,10 +119,10 @@ Status LocalExecutable::ValidateExecutionOptions( return InvalidArgument( "executable is built for device %s of type \"%s\"; cannot run it on " "device %s of type \"%s\"", - backend_->device_name(build_device_ordinal()).c_str(), - build_executor->GetDeviceDescription().name().c_str(), - backend_->device_name(run_device_ordinal).c_str(), - run_executor->GetDeviceDescription().name().c_str()); + backend_->device_name(build_device_ordinal()), + build_executor->GetDeviceDescription().name(), + backend_->device_name(run_device_ordinal), + run_executor->GetDeviceDescription().name()); } if (!run_options.allocator()) { @@ -133,15 +132,15 @@ Status LocalExecutable::ValidateExecutionOptions( if (run_options.allocator()->platform() != backend.platform()) { return InvalidArgument( "allocator platform (%s) does not match service platform (%s)", - run_options.allocator()->platform()->Name().c_str(), - backend.platform()->Name().c_str()); + run_options.allocator()->platform()->Name(), + backend.platform()->Name()); } return Status::OK(); } StatusOr LocalExecutable::Run( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, ExecutableRunOptions run_options) { TF_RETURN_IF_ERROR( ValidateExecutionOptions(arguments, run_options, *backend_)); @@ -178,7 +177,7 @@ StatusOr LocalExecutable::Run( StatusOr LocalExecutable::ExecuteAndDump( const ServiceExecutableRunOptions* run_options, - const tensorflow::gtl::ArraySlice arguments) { + const absl::Span arguments) { executable_->hlo_snapshot()->set_execution_platform( backend_->platform()->Name()); TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot())); @@ -192,7 +191,7 @@ StatusOr LocalExecutable::ExecuteAndDump( } Status LocalExecutable::RecordArguments( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, HloSnapshot* hlo_snapshot) { hlo_snapshot->clear_arguments(); for (const ShapedBuffer* argument : arguments) { @@ 
-246,7 +245,7 @@ Backend* LocalClient::mutable_backend() { StatusOr> LocalClient::Compile( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& options) { ExecutableBuildOptions updated_options = options; if (options.device_ordinal() == -1) { @@ -257,9 +256,9 @@ StatusOr> LocalClient::Compile( TF_ASSIGN_OR_RETURN(std::unique_ptr executable, local_service_->CompileExecutable( computation, argument_layouts, updated_options)); - return WrapUnique(new LocalExecutable(std::move(executable), - local_service_->mutable_backend(), - updated_options)); + return absl::WrapUnique(new LocalExecutable(std::move(executable), + local_service_->mutable_backend(), + updated_options)); } StatusOr LocalClient::LiteralToShapedBuffer( diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index ae23809261757c637ab4aec036750c371ac60cdc..56c3a3da023ebf124b4bd91c2c608d0cd00a2381 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -40,7 +40,7 @@ class LocalExecutable { // Run the compiled computation with the given arguments and options and // return the result. 
StatusOr Run( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, ExecutableRunOptions run_options); // Return the options used to build the executable. @@ -63,7 +63,7 @@ class LocalExecutable { // The given ExecutableRunOptions override any values from legacy_flags // (TF_XLA_FLAGS environment variable). Status ValidateExecutionOptions( - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to @@ -73,13 +73,12 @@ class LocalExecutable { // (TF_XLA_FLAGS environment variable). StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, - const tensorflow::gtl::ArraySlice arguments); + const absl::Span arguments); // Records the arguments used to invoke the computation in a SessionModule // proto. - Status RecordArguments( - const tensorflow::gtl::ArraySlice arguments, - HloSnapshot* hlo_snapshot); + Status RecordArguments(const absl::Span arguments, + HloSnapshot* hlo_snapshot); // Records the result of the computation in a SessionModule proto. Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); @@ -120,7 +119,7 @@ class LocalClient : public Client { // (TF_XLA_FLAGS environment variable). StatusOr> Compile( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& options); // Copy the literal data to the device with the given ordinal and return as a diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index 6a9cf466ac0a43ce214ef0e6aae9e6295f137b0f..992b13139c480900e7b983825be61ce88f14e11b 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -23,16 +23,15 @@ limitations under the License. 
namespace xla { -Status ValidatePaddingValues( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides) { +Status ValidatePaddingValues(absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides) { bool ok = input_dimensions.size() == window_dimensions.size() && input_dimensions.size() == window_strides.size(); if (!ok) { return InvalidArgument( - "Want input dimensions size %zu = window dimensions size %zu = window " - "strides size %zu", + "Want input dimensions size %u = window dimensions size %u = window " + "strides size %u", input_dimensions.size(), window_dimensions.size(), window_strides.size()); } @@ -40,9 +39,9 @@ Status ValidatePaddingValues( } std::vector> MakePadding( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { + absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { TF_CHECK_OK(ValidatePaddingValues(input_dimensions, window_dimensions, window_strides)); std::vector> low_high_padding; diff --git a/tensorflow/compiler/xla/client/padding.h b/tensorflow/compiler/xla/client/padding.h index e23b0b3a90a091bf80973525810793c3eda4a036..5c009bd49e48b158550a32e64b0d63e2840dd1a9 100644 --- a/tensorflow/compiler/xla/client/padding.h +++ b/tensorflow/compiler/xla/client/padding.h @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -41,10 +41,9 @@ enum class Padding { // Validates that the slices are acceptable for determining padding -- this can // be used to check the preconditions of MakePadding below to produce an error // message that can be returned to the user. 
-Status ValidatePaddingValues( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides); +Status ValidatePaddingValues(absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides); // Returns the padding needed for the base area, given the base area dimensions, // window dimensions, strides, and the type of padding. @@ -58,9 +57,9 @@ Status ValidatePaddingValues( // window_dimensions, and strides must match, which is equal to the number // of elements in the result vector. std::vector> MakePadding( - tensorflow::gtl::ArraySlice input_dimensions, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); + absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); } // namespace xla diff --git a/tensorflow/compiler/xla/client/sharding_builder.h b/tensorflow/compiler/xla/client/sharding_builder.h index 34763e54d946690289ff42a7712b980168933eee..59df3a8762c755848982bc8e2590de968ed2adb6 100644 --- a/tensorflow/compiler/xla/client/sharding_builder.h +++ b/tensorflow/compiler/xla/client/sharding_builder.h @@ -56,4 +56,4 @@ OpSharding Tuple(const ShapeTree& shardings); } // namespace sharding_builder } // namespace xla -#endif +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_SHARDING_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index b3b00e2fffe1196b36190ec72d1425bae4e4e276..e639028ccda11ae7e873f601c2f95749bce178c0 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -21,19 +21,24 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" namespace xla { -using tensorflow::strings::StrCat; +using absl::StrCat; namespace { @@ -67,7 +72,7 @@ XlaOp operator>>(const XlaOp& x, const XlaOp& y) { if (!ShapeUtil::ElementIsIntegral(shape)) { return InvalidArgument( "Argument to >> operator does not have an integral type (%s).", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } if (ShapeUtil::ElementIsSigned(shape)) { return ShiftRightArithmetic(x, y); @@ -85,7 +90,7 @@ StatusOr XlaBuilder::GetShape(const XlaOp& op) const { } StatusOr> XlaBuilder::GetOperandShapes( - tensorflow::gtl::ArraySlice operands) const { + absl::Span operands) const { std::vector operand_shapes; for (const XlaOp& operand : operands) { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); @@ -194,7 +199,6 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, // TODO(b/33009255): Implmement constant folding for cross replica sum. 
case HloOpcode::kInfeed: case HloOpcode::kOutfeed: - case HloOpcode::kHostCompute: case HloOpcode::kCall: // TODO(b/32495713): We aren't checking the to_apply computation itself, // so we conservatively say that computations containing the Call op @@ -221,8 +225,7 @@ XlaComputation XlaBuilder::BuildAndNoteError() { auto build_status = Build(); if (!build_status.ok()) { parent_builder_->ReportError( - AddStatus(build_status.status(), - tensorflow::strings::StrCat("error from: ", name_))); + AddStatus(build_status.status(), absl::StrCat("error from: ", name_))); return {}; } return build_status.ConsumeValueOrDie(); @@ -288,7 +291,7 @@ StatusOr XlaBuilder::Build(int64 root_id) { StatusOr XlaBuilder::InDimBroadcast( const Shape& shape, const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(first_error_); HloInstructionProto instr; @@ -349,9 +352,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { }); } -XlaOp XlaBuilder::BinaryOp( - HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -445,12 +447,12 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); } @@ -463,14 +465,27 @@ XlaOp 
XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { }); } +XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + instr.add_dimensions(iota_dimension); + return AddInstruction(std::move(instr), HloOpcode::kIota); + }); +} + +XlaOp XlaBuilder::Iota(PrimitiveType type, int64 size) { + return Iota(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0); +} + XlaOp XlaBuilder::Call(const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( @@ -489,7 +504,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (!parameter_numbers_.insert(parameter_number).second) { - return InvalidArgument("parameter %lld already registered", + return InvalidArgument("parameter %d already registered", parameter_number); } instr.set_parameter_number(parameter_number); @@ -499,8 +514,8 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, }); } -XlaOp XlaBuilder::Broadcast( - const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { +XlaOp XlaBuilder::Broadcast(const XlaOp& operand, + absl::Span broadcast_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( @@ -525,7 
+540,7 @@ XlaOp XlaBuilder::Broadcast( XlaOp XlaBuilder::BroadcastInDim( const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions) { + const absl::Span broadcast_dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { return InDimBroadcast(shape, operand, broadcast_dimensions); }); @@ -540,9 +555,9 @@ StatusOr XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) { } XlaOp XlaBuilder::Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -577,7 +592,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index, } XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -615,15 +630,15 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, }); } -XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice operands, +XlaOp XlaBuilder::ConcatInDim(absl::Span operands, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension)); @@ -655,8 +670,8 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const 
XlaOp& padding_value, } XlaOp XlaBuilder::Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { + absl::Span dimensions, + absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, @@ -670,7 +685,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, } XlaOp XlaBuilder::Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes) { + absl::Span new_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); @@ -680,7 +695,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, } XlaOp XlaBuilder::Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { if (dimensions.size() <= 1) { // Not collapsing anything, trivially we can return the operand versus @@ -690,8 +705,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, // Out-of-order collapse is not supported. // Checks that the collapsed dimensions are in order and consecutive. 
- for (tensorflow::gtl::ArraySlice::size_type i = 1; - i < dimensions.size(); ++i) { + for (absl::Span::size_type i = 1; i < dimensions.size(); ++i) { if (dimensions[i] - 1 != dimensions[i - 1]) { return InvalidArgument( "Collapsed dimensions are not in consecutive order."); @@ -703,8 +717,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand)); VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape); - VLOG(3) << "dims to collapse: " - << tensorflow::str_util::Join(dimensions, ","); + VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ","); std::vector new_sizes; for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) { @@ -715,8 +728,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, } } - VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") - << "]"; + VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]"; return Reshape(operand, new_sizes); }); @@ -744,13 +756,13 @@ XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, }); } -XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice elements) { +XlaOp XlaBuilder::Tuple(absl::Span elements) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferVariadicOpShape( HloOpcode::kTuple, operand_shape_ptrs)); @@ -765,7 +777,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { if (!ShapeUtil::IsTuple(tuple_shape)) { return InvalidArgument( "Operand to GetTupleElement() is not a tuple; got %s", - 
ShapeUtil::HumanString(tuple_shape).c_str()); + ShapeUtil::HumanString(tuple_shape)); } *instr.mutable_shape() = ShapeUtil::GetTupleElementShape(tuple_shape, index); @@ -778,36 +790,37 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { } XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions); } -XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { +XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfigProto* precision_config_proto) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -815,12 +828,14 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) { dimension_numbers.add_lhs_contracting_dimensions( lhs_shape.dimensions_size() == 
1 ? 0 : 1); dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers); + return DotGeneral(lhs, rhs, dimension_numbers, precision_config_proto); }); } -XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers) { +XlaOp XlaBuilder::DotGeneral( + const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto* precision_config_proto) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -829,6 +844,9 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); *instr.mutable_dot_dimension_numbers() = dimension_numbers; + if (precision_config_proto != nullptr) { + *instr.mutable_precision_config() = *precision_config_proto; + } return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); }); } @@ -840,16 +858,14 @@ Status XlaBuilder::VerifyConvolution( return InvalidArgument( "Convolution arguments must have same number of " "dimensions. Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str()); + ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape)); } int num_dims = ShapeUtil::Rank(lhs_shape); if (num_dims < 2) { return InvalidArgument( "Convolution expects argument arrays with >= 3 dimensions. 
" "Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str()); + ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape)); } int num_spatial_dims = num_dims - 2; @@ -863,7 +879,7 @@ Status XlaBuilder::VerifyConvolution( } for (int i = 0; i < numbers.size(); ++i) { if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { - return InvalidArgument("Convolution %s[%d] is out of bounds: %lld", + return InvalidArgument("Convolution %s[%d] is out of bounds: %d", field_name, i, numbers.Get(i)); } } @@ -881,25 +897,30 @@ Status XlaBuilder::VerifyConvolution( } XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding) { + absl::Span window_strides, Padding padding, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); + CreateDefaultConvDimensionNumbers(window_strides.size()), + feature_group_count, precision_config_proto); } XlaOp XlaBuilder::ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return ConvGeneral(lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); + CreateDefaultConvDimensionNumbers(window_strides.size()), + feature_group_count, precision_config_proto); } XlaOp XlaBuilder::ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + Padding padding, const 
ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -926,26 +947,29 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( return ConvGeneral(lhs, rhs, window_strides, MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), - dimension_numbers); + dimension_numbers, feature_group_count, + precision_config_proto); }); } XlaOp XlaBuilder::ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, - dimension_numbers); + dimension_numbers, feature_group_count, + precision_config_proto); } XlaOp XlaBuilder::ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -964,12 +988,17 @@ XlaOp XlaBuilder::ConvGeneralDilated( 
MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, rhs_dilation)); - TF_ASSIGN_OR_RETURN( - *instr.mutable_shape(), - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, instr.window(), - dimension_numbers)); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, instr.window(), + dimension_numbers, feature_group_count)); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; + instr.set_feature_group_count(feature_group_count); + + if (precision_config_proto != nullptr) { + *instr.mutable_precision_config() = *precision_config_proto; + } return AddInstruction(std::move(instr), HloOpcode::kConvolution, {lhs, rhs}); @@ -977,22 +1006,21 @@ XlaOp XlaBuilder::ConvGeneralDilated( } StatusOr XlaBuilder::MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation) const { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation) const { const auto verify_size = [&](const size_t x, const char* x_name) { if (x == 0 || x == window_dimensions.size()) { return Status::OK(); } else { return InvalidArgument( - "%s", tensorflow::strings::StrCat( + "%s", absl::StrCat( "Window has different number of window dimensions than of ", x_name, "\nNumber of window dimensions: ", window_dimensions.size(), - "\nNumber of ", x_name, ": ", x, "\n") - .c_str()); + "\nNumber of ", x_name, ": ", x, "\n")); } }; TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides")); @@ -1032,7 +1060,7 @@ StatusOr XlaBuilder::MakeWindow( } XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { + const absl::Span fft_length) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; 
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1073,6 +1101,23 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { "Replicated sharding is not yet supported for infeeds"); } + // Infeed takes a single token operand. Generate the token to pass to the + // infeed. + XlaOp token; + auto make_token = [&]() { + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {}); + }; + if (sharding()) { + // Arbitrarily assign token to device 0. + OpSharding sharding = sharding_builder::AssignDevice(0); + XlaScopedShardingAssignment scoped_sharding(this, sharding); + TF_ASSIGN_OR_RETURN(token, make_token()); + } else { + TF_ASSIGN_OR_RETURN(token, make_token()); + } + // The sharding is set by the client according to the data tuple shape. // However, the shape of the infeed instruction is a tuple containing the // data and a token. For tuple sharding type, the sharding must be changed @@ -1088,11 +1133,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { sharding_builder::AssignDevice(0); XlaScopedShardingAssignment scoped_sharding(this, infeed_instruction_sharding); - TF_ASSIGN_OR_RETURN( - infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {})); + TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr), + HloOpcode::kInfeed, {token})); } else { - TF_ASSIGN_OR_RETURN( - infeed, AddInstruction(std::move(instr), HloOpcode::kInfeed, {})); + TF_ASSIGN_OR_RETURN(infeed, AddInstruction(std::move(instr), + HloOpcode::kInfeed, {token})); } // The infeed instruction produces a tuple of the infed data and a token @@ -1151,15 +1196,22 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { return InvalidArgument( "Outfeed shape %s must be compatible with operand shape %s", - 
ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), - ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + ShapeUtil::HumanStringWithLayout(shape_with_layout), + ShapeUtil::HumanStringWithLayout(operand_shape)); } *instr.mutable_outfeed_shape() = shape_with_layout; instr.set_outfeed_config(outfeed_config); + // Outfeed takes a token as its second operand. Generate the token to pass + // to the outfeed. + HloInstructionProto token_instr; + *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape(); + TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr), + HloOpcode::kAfterAll, {})); + TF_RETURN_IF_ERROR( - AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand}) + AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand, token}) .status()); // The outfeed instruction produces a token. However, existing users expect @@ -1197,8 +1249,8 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { return InvalidArgument( "Outfeed shape %s must be compatible with operand shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), - ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + ShapeUtil::HumanStringWithLayout(shape_with_layout), + ShapeUtil::HumanStringWithLayout(operand_shape)); } *instr.mutable_outfeed_shape() = shape_with_layout; @@ -1217,7 +1269,7 @@ XlaOp XlaBuilder::CreateToken() { }); } -XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice tokens) { +XlaOp XlaBuilder::AfterAll(absl::Span tokens) { return ReportErrorOrReturn([&]() -> StatusOr { if (tokens.empty()) { return InvalidArgument("AfterAll requires at least one operand"); @@ -1229,15 +1281,15 @@ XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice tokens) { } XlaOp XlaBuilder::CustomCall(const string& call_target_name, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr 
{ HloInstructionProto instr; - if (tensorflow::str_util::StartsWith(call_target_name, "$")) { + if (absl::StartsWith(call_target_name, "$")) { return InvalidArgument( "Invalid custom_call_target \"%s\": Call targets that start with '$' " "are reserved for internal use.", - call_target_name.c_str()); + call_target_name); } *instr.mutable_shape() = shape; instr.set_custom_call_target(call_target_name); @@ -1245,21 +1297,8 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, }); } -XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice operands, - const string& channel_name, - int64 cost_estimate_ns, const Shape& shape) { - return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - *instr.mutable_shape() = shape; - instr.set_channel_name(channel_name); - instr.set_cost_estimate_ns(cost_estimate_ns); - return AddInstruction(std::move(instr), HloOpcode::kHostCompute, operands); - }); -} - -XlaOp XlaBuilder::Complex( - const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::Complex(const XlaOp& real, const XlaOp& imag, + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions); } @@ -1268,42 +1307,42 @@ XlaOp XlaBuilder::Conj(const XlaOp& operand) { } XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Max(const 
XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions); } @@ -1311,22 +1350,21 @@ XlaOp XlaBuilder::Not(const XlaOp& operand) { return UnaryOp(HloOpcode::kNot, operand); } -XlaOp XlaBuilder::ShiftLeft( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::ShiftRightArithmetic( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs, broadcast_dimensions); } XlaOp XlaBuilder::ShiftRightLogical( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs, broadcast_dimensions); } @@ -1335,9 +1373,8 @@ 
XlaOp XlaBuilder::Abs(const XlaOp& operand) { return UnaryOp(HloOpcode::kAbs, operand); } -XlaOp XlaBuilder::Atan2( - const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp XlaBuilder::Atan2(const XlaOp& y, const XlaOp& x, + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions); } @@ -1402,7 +1439,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) { } XlaOp XlaBuilder::Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation) { + absl::Span permutation) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1417,7 +1454,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, } XlaOp XlaBuilder::Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); @@ -1431,7 +1468,7 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand, }); } -XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional values, +XlaOp XlaBuilder::Sort(XlaOp keys, absl::optional values, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1459,7 +1496,7 @@ XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional values, } XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions); } @@ -1497,10 +1534,10 @@ XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand, return TernaryOp(HloOpcode::kClamp, min, operand, max); } -XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, +XlaOp XlaBuilder::Map(absl::Span operands, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice 
static_operands) { + absl::Span dimensions, + absl::Span static_operands) { return ReportErrorOrReturn([&]() -> StatusOr { if (!static_operands.empty()) { return Unimplemented("static_operands is not supported in Map"); @@ -1509,8 +1546,8 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); - c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN( @@ -1541,7 +1578,7 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice operands, } XlaOp XlaBuilder::RngOp(RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, + absl::Span parameters, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1553,7 +1590,7 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution, if (parameters.size() != 2) { return InvalidArgument( "RNG distribution (%s) expects 2 parameters, but got %ld", - RandomDistribution_Name(distribution).c_str(), parameters.size()); + RandomDistribution_Name(distribution), parameters.size()); } break; default: @@ -1600,27 +1637,27 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, }); } -XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, +XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds) { + absl::Span slice_sizes) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); - TF_ASSIGN_OR_RETURN(const Shape& gather_indices_shape, - 
GetShape(gather_indices)); + TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, + GetShape(start_indices)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), - ShapeInference::InferGatherShape(input_shape, gather_indices_shape, - dimension_numbers, window_bounds)); + ShapeInference::InferGatherShape(input_shape, start_indices_shape, + dimension_numbers, slice_sizes)); *instr.mutable_gather_dimension_numbers() = dimension_numbers; - for (int64 bound : window_bounds) { - instr.add_gather_window_bounds(bound); + for (int64 bound : slice_sizes) { + instr.add_gather_slice_sizes(bound); } return AddInstruction(std::move(instr), HloOpcode::kGather, - {input, gather_indices}); + {input, start_indices}); }); } @@ -1682,22 +1719,39 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, }); } -XlaOp XlaBuilder::Reduce( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { +XlaOp XlaBuilder::Reduce(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + absl::Span dimensions_to_reduce) { + return Reduce(absl::Span({operand}), + absl::Span({init_value}), computation, + dimensions_to_reduce); +} + +XlaOp XlaBuilder::Reduce(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); - TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value)); TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, computation.GetProgramShape()); - TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), - ShapeInference::InferReduceShape( - {&operand_shape, &init_shape}, dimensions_to_reduce, - called_program_shape)); + std::vector all_operands; + all_operands.insert(all_operands.end(), operands.begin(), operands.end()); + 
all_operands.insert(all_operands.end(), init_values.begin(), + init_values.end()); + + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, + GetOperandShapes(all_operands)); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferReduceShape( + operand_shape_ptrs, dimensions_to_reduce, called_program_shape)); for (int64 dim : dimensions_to_reduce) { instr.add_dimensions(dim); @@ -1705,8 +1759,7 @@ XlaOp XlaBuilder::Reduce( AddCalledComputation(computation, &instr); - return AddInstruction(std::move(instr), HloOpcode::kReduce, - {operand, init_value}); + return AddInstruction(std::move(instr), HloOpcode::kReduce, all_operands); }); } @@ -1720,11 +1773,11 @@ XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value, }); } -XlaOp XlaBuilder::ReduceWindow( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { +XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1745,9 +1798,9 @@ XlaOp XlaBuilder::ReduceWindow( XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1842,8 +1895,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, 
} XlaOp XlaBuilder::CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_group_ids) { + const XlaOp& operand, absl::Span replica_groups) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); @@ -1851,23 +1903,24 @@ XlaOp XlaBuilder::CrossReplicaSum( b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); TF_ASSIGN_OR_RETURN(auto computation, b->Build()); - return CrossReplicaSum(operand, computation, replica_group_ids, - /*channel_id=*/tensorflow::gtl::nullopt); + return CrossReplicaSum(operand, computation, replica_groups, + /*channel_id=*/absl::nullopt); }); } XlaOp XlaBuilder::CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_group_ids, - const tensorflow::gtl::optional& channel_id) { + absl::Span replica_groups, + const absl::optional& channel_id) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferCrossReplicaSumShape({&operand_shape})); - for (int64 replica_group_id : replica_group_ids) { - instr.add_replica_group_ids(replica_group_id); + + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; } if (channel_id.has_value()) { @@ -1914,8 +1967,8 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, HloInstructionProto instr; TF_ASSIGN_OR_RETURN(auto slice_shapes, this->GetOperandShapes(slices)); std::vector slice_shape_ptrs; - c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), - [](const Shape& shape) { return &shape; }); + absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs), + [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( 
*instr.mutable_shape(), ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs)); @@ -1936,12 +1989,34 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, }); } -XlaOp XlaBuilder::SelectAndScatter( - const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter) { +XlaOp XlaBuilder::CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferCollectivePermuteShape(operand_shape)); + + for (const auto& pair : source_target_pairs) { + auto* proto_pair = instr.add_source_target_pairs(); + proto_pair->set_source(pair.first); + proto_pair->set_target(pair.second); + } + + return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute, + {operand}); + }); +} + +XlaOp XlaBuilder::SelectAndScatter(const XlaOp& operand, + const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); return SelectAndScatterWithGeneralPadding( @@ -1954,11 +2029,10 @@ XlaOp XlaBuilder::SelectAndScatter( XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> 
padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -2102,13 +2176,13 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { return InvalidArgument( "SendToHost shape %s must be compatible with operand shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), - ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + ShapeUtil::HumanStringWithLayout(shape_with_layout), + ShapeUtil::HumanStringWithLayout(operand_shape)); } // TODO(b/111544877): Support tuple shapes. if (!ShapeUtil::IsArray(operand_shape)) { return InvalidArgument("SendToHost only supports array shapes, shape: %s", - ShapeUtil::HumanString(operand_shape).c_str()); + ShapeUtil::HumanString(operand_shape)); } if (handle.type() != ChannelHandle::DEVICE_TO_HOST) { @@ -2147,7 +2221,7 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, if (!ShapeUtil::IsArray(shape)) { return InvalidArgument( "RecvFromHost only supports array shapes, shape: %s", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } if (handle.type() != ChannelHandle::HOST_TO_DEVICE) { @@ -2202,7 +2276,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( "of being evaluated at XLA compile time.\n\n" "Please file a usability bug with the framework being used (e.g. 
" "TensorFlow).", - op_string.c_str()); + op_string); } TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, @@ -2265,7 +2339,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( std::unique_ptr XlaBuilder::CreateSubBuilder( const string& computation_name) { - auto sub_builder = MakeUnique(computation_name); + auto sub_builder = absl::make_unique(computation_name); sub_builder->parent_builder_ = this; sub_builder->die_immediately_on_error_ = this->die_immediately_on_error_; return sub_builder; @@ -2310,8 +2384,8 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)}) .size() != 4) { return FailedPrecondition( - "dimension numbers for the input are not unique: (%lld, %lld, %lld, " - "%lld)", + "dimension numbers for the input are not unique: (%d, %d, %d, " + "%d)", dnum.input_batch_dimension(), dnum.input_feature_dimension(), dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)); } @@ -2321,8 +2395,8 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dnum.kernel_spatial_dimensions(1)}) .size() != 4) { return FailedPrecondition( - "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " - "%lld)", + "dimension numbers for the weight are not unique: (%d, %d, %d, " + "%d)", dnum.kernel_output_feature_dimension(), dnum.kernel_input_feature_dimension(), dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1)); @@ -2333,17 +2407,17 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dnum.output_spatial_dimensions(1)}) .size() != 4) { return FailedPrecondition( - "dimension numbers for the output are not unique: (%lld, %lld, %lld, " - "%lld)", + "dimension numbers for the output are not unique: (%d, %d, %d, " + "%d)", dnum.output_batch_dimension(), dnum.output_feature_dimension(), dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1)); } return Status::OK(); } -StatusOr 
XlaBuilder::AddInstruction( - HloInstructionProto&& instr, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { +StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode, + absl::Span operands) { TF_RETURN_IF_ERROR(first_error_); const int64 handle = instructions_.size(); @@ -2354,13 +2428,11 @@ StatusOr XlaBuilder::AddInstruction( } for (const auto& operand : operands) { if (operand.builder_ == nullptr) { - return InvalidArgument("invalid XlaOp with handle %lld", - operand.handle()); + return InvalidArgument("invalid XlaOp with handle %d", operand.handle()); } if (operand.builder_ != this) { return InvalidArgument("Do not add XlaOp from builder %s to builder %s", - operand.builder_->name().c_str(), - this->name().c_str()); + operand.builder_->name(), this->name()); } instr.add_operand_ids(operand.handle()); } @@ -2390,18 +2462,18 @@ StatusOr XlaBuilder::LookUpInstruction( if (op.builder_ == nullptr) { return InvalidArgument( - "invalid XlaOp with handle %lld; the builder of this op is freed", + "invalid XlaOp with handle %d; the builder of this op is freed", op.handle()); } if (op.builder_ != this) { return InvalidArgument( - "XlaOp with handle %lld is built by builder '%s', but is trying to use " + "XlaOp with handle %d is built by builder '%s', but is trying to use " "it in builder '%s'", - op.handle(), op.builder_->name().c_str(), this->name().c_str()); + op.handle(), op.builder_->name(), this->name()); } if (op.handle() >= instructions_.size() || op.handle() < 0) { - return InvalidArgument("no XlaOp value %lld", op.handle()); + return InvalidArgument("no XlaOp value %d", op.handle()); } return &instructions_[op.handle()]; } @@ -2419,14 +2491,12 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) { return builder->ConstantLiteral(literal); } -XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { +XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes) { return 
operand.builder()->Broadcast(operand, broadcast_sizes); } -XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + const absl::Span broadcast_dimensions) { return operand.builder()->BroadcastInDim(operand, shape, broadcast_dimensions); } @@ -2436,26 +2506,22 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, return operand.builder()->Pad(operand, padding_value, padding_config); } -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { +XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes) { return operand.builder()->Reshape(operand, dimensions, new_sizes); } -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes) { +XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes) { return operand.builder()->Reshape(operand, new_sizes); } -XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions) { +XlaOp Collapse(const XlaOp& operand, absl::Span dimensions) { return operand.builder()->Collapse(operand, dimensions); } -XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { +XlaOp Slice(const XlaOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return operand.builder()->Slice(operand, start_indices, limit_indices, strides); } @@ -2467,7 +2533,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, } XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes); } @@ -2476,8 +2542,7 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, return 
operand.builder()->DynamicUpdateSlice(operand, update, start_indices); } -XlaOp ConcatInDim(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, +XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, int64 dimension) { return builder->ConcatInDim(operands, dimension); } @@ -2490,7 +2555,7 @@ XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) { return pred.builder()->Select(pred, on_true, on_false); } -XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements) { +XlaOp Tuple(XlaBuilder* builder, absl::Span elements) { return builder->Tuple(elements); } @@ -2499,87 +2564,101 @@ XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) { } XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions); } XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions); } XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions); } XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions); } XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions); } XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Le(lhs, rhs, broadcast_dimensions); } -XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs) { - return lhs.builder()->Dot(lhs, rhs); +XlaOp Dot(const XlaOp& lhs, const 
XlaOp& rhs, + const PrecisionConfigProto* precision_config_proto) { + return lhs.builder()->Dot(lhs, rhs, precision_config_proto); } XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers) { - return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto* precision_config_proto) { + return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers, + precision_config_proto); } XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { - return lhs.builder()->Conv(lhs, rhs, window_strides, padding); + absl::Span window_strides, Padding padding, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { + return lhs.builder()->Conv(lhs, rhs, window_strides, padding, + feature_group_count, precision_config_proto); } XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, - padding); + padding, feature_group_count, + precision_config_proto); } XlaOp ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { - return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides, - padding, dimension_numbers); + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { + return lhs.builder()->ConvWithGeneralDimensions( + lhs, rhs, window_strides, 
padding, dimension_numbers, feature_group_count, + precision_config_proto); } XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, - dimension_numbers); + dimension_numbers, feature_group_count, + precision_config_proto); } -XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { - return lhs.builder()->ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, - dimension_numbers); +XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto) { + return lhs.builder()->ConvGeneralDilated( + lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count, precision_config_proto); } XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length) { + absl::Span fft_length) { return operand.builder()->Fft(operand, fft_type, fft_length); } @@ -2593,106 +2672,106 @@ void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, } XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { return builder->Call(computation, 
operands); } XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape) { + absl::Span operands, const Shape& shape) { return builder->CustomCall(call_target_name, operands, shape); } -XlaOp HostCompute(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, - const Shape& shape) { - return builder->HostCompute(operands, channel_name, cost_estimate_ns, shape); -} - XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return real.builder()->Complex(real, imag, broadcast_dimensions); } XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); } XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Add(lhs, rhs, broadcast_dimensions); } XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions); } XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions); } XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Div(lhs, rhs, broadcast_dimensions); } XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions); } XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Max(lhs, rhs, broadcast_dimensions); } XlaOp Min(const XlaOp& 
lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Min(lhs, rhs, broadcast_dimensions); } XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->And(lhs, rhs, broadcast_dimensions); } XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Or(lhs, rhs, broadcast_dimensions); } XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions); } XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); } XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions); } -XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions); } -XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { +XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions) { return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions); } XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + absl::Span dimensions_to_reduce) { return operand.builder()->Reduce(operand, init_value, computation, dimensions_to_reduce); } +// Reduces several arrays simultaneously among the provided dimensions, 
given +// "computation" as a reduction operator. +XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce) { + return builder->Reduce(operands, init_values, computation, + dimensions_to_reduce); +} + XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation) { return operand.builder()->ReduceAll(operand, init_value, computation); @@ -2700,9 +2779,8 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding) { + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { return operand.builder()->ReduceWindow(operand, init_value, computation, window_dimensions, window_strides, padding); @@ -2711,25 +2789,24 @@ XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding) { return operand.builder()->ReduceWindowWithGeneralPadding( operand, init_value, computation, window_dimensions, window_strides, padding); } XlaOp CrossReplicaSum(const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_group_ids) { - return operand.builder()->CrossReplicaSum(operand, replica_group_ids); + absl::Span replica_groups) { + return operand.builder()->CrossReplicaSum(operand, replica_groups); } -XlaOp CrossReplicaSum( - const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_group_ids, - const tensorflow::gtl::optional& channel_id) { +XlaOp 
CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, + absl::Span replica_groups, + const absl::optional& channel_id) { return operand.builder()->CrossReplicaSum(operand, computation, - replica_group_ids, channel_id); + replica_groups, channel_id); } XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, @@ -2739,11 +2816,17 @@ XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, split_count, replica_groups); } +XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs) { + return operand.builder()->CollectivePermute(operand, source_target_pairs); +} + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, const XlaOp& source, - const XlaOp& init_value, const XlaComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter) { return operand.builder()->SelectAndScatter(operand, select, window_dimensions, window_strides, padding, source, init_value, scatter); @@ -2751,11 +2834,10 @@ XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter) { return operand.builder()->SelectAndScatterWithGeneralPadding( operand, select, window_dimensions, window_strides, padding, source, init_value, scatter); @@ -2764,7 +2846,7 @@ XlaOp SelectAndScatterWithGeneralPadding( XlaOp Abs(const 
XlaOp& operand) { return operand.builder()->Abs(operand); } XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return y.builder()->Atan2(y, x, broadcast_dimensions); } @@ -2797,7 +2879,7 @@ XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); } XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); } XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions); } @@ -2815,17 +2897,15 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) { XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); } -XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation) { +XlaOp Transpose(const XlaOp& operand, absl::Span permutation) { return operand.builder()->Transpose(operand, permutation); } -XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions) { +XlaOp Rev(const XlaOp& operand, absl::Span dimensions) { return operand.builder()->Rev(operand, dimensions); } -XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values, - int64 dimension) { +XlaOp Sort(XlaOp keys, absl::optional values, int64 dimension) { return keys.builder()->Sort(keys, std::move(values), dimension); } @@ -2833,10 +2913,9 @@ XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) { return min.builder()->Clamp(min, operand, max); } -XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { +XlaOp Map(XlaBuilder* builder, absl::Span operands, + const XlaComputation& computation, absl::Span dimensions, + absl::Span static_operands) { return builder->Map(operands, computation, dimensions, static_operands); } @@ -2868,11 
+2947,11 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, mantissa_bits); } -XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, +XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds) { - return input.builder()->Gather(input, gather_indices, dimension_numbers, - window_bounds); + absl::Span slice_sizes) { + return input.builder()->Gather(input, start_indices, dimension_numbers, + slice_sizes); } XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, @@ -2926,7 +3005,7 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); } -XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice tokens) { +XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens) { return builder->AfterAll(tokens); } @@ -2953,11 +3032,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, grad_output, epsilon, feature_index); } -XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) { - HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size}); - return builder->ReportErrorOrReturn( - builder->AddInstruction(std::move(instr), HloOpcode::kIota)); +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { + return builder->Iota(type, size); +} + +XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) { + return builder->Iota(shape, iota_dimension); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 9403d7ca8dabc80a3964b50d29f158a98091f843..59fbc664f2b35fd00f9b9094d6147847d03797ea 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -21,6 +21,8 @@ limitations under the License. 
#include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -32,8 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stacktrace.h" @@ -154,12 +154,10 @@ class XlaBuilder { // Clears the sharding. Ops will be sharded according to the default placement // policy. - void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } + void ClearSharding() { sharding_ = absl::nullopt; } // Returns the OpSharding that will be attached to all instructions. - const tensorflow::gtl::optional& sharding() const { - return sharding_; - } + const absl::optional& sharding() const { return sharding_; } // Sets the builder to a mode where it will die immediately when an error is // encountered, rather than producing it in a deferred fashion when Build() is @@ -296,7 +294,7 @@ class XlaBuilder { template XlaOp ConstantR0(NativeT value); template - XlaOp ConstantR1(tensorflow::gtl::ArraySlice values); + XlaOp ConstantR1(absl::Span values); XlaOp ConstantR1(const tensorflow::core::Bitmap& values); template XlaOp ConstantR2( @@ -338,7 +336,7 @@ class XlaBuilder { // // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + absl::Span broadcast_sizes); // Performs in-dimension-style broadcast. 
// @@ -357,9 +355,8 @@ class XlaBuilder { // will generate output // [1 , 1] // [2 , 2] - XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions); + XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + const absl::Span broadcast_dimensions); // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config @@ -372,15 +369,13 @@ class XlaBuilder { // given, followed by reshaping it into the shape with the given dimension // sizes (also major to minor). Conceptually, this is a limited form of // "shape casting". - XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes); // Enqueues an operation onto the computation that collapses the operand, from // first to last dimension (C order), then reshapes it to the given dimension // sizes. Conceptually, this is a limited form of "shape casting". - XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); + XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); // Wrapper for Reshape. // Enqueues an operation to collapse the provided dimensions; e.g. an @@ -400,8 +395,7 @@ class XlaBuilder { // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. - XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); // Enqueues a slice operation onto the computation that slices the operand // from the start indices to the limit indices; e.g. @@ -414,10 +408,9 @@ class XlaBuilder { // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D // range notation. 
// The strides parameter determines the stride over the slice - XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + XlaOp Slice(const XlaOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); // Enqueues a slice operation in a given dimension, taking all other // dimensions as they are; e.g. if dimno is 1 from start_index 2 to @@ -438,7 +431,7 @@ class XlaBuilder { // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. @@ -461,8 +454,7 @@ class XlaBuilder { // Enqueues a concatenate instruction onto the computation. 'operands' must // have >= 1 entry. - XlaOp ConcatInDim(tensorflow::gtl::ArraySlice operands, - int64 dimension); + XlaOp ConcatInDim(absl::Span operands, int64 dimension); // Enqueue a tracing operation onto the computation; the computation will emit // a logging message with the operand. @@ -473,84 +465,96 @@ class XlaBuilder { XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); // Enqueues a tuple-creation instruction onto the computation. - XlaOp Tuple(tensorflow::gtl::ArraySlice elements); + XlaOp Tuple(absl::Span elements); // Enqueues a tuple-element-get instruction onto the computation. XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); // Enqueues an equal-to comparison instruction onto the computation. 
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a dot instruction onto the computation. - XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a general dot instruction onto the computation. - XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers); + XlaOp DotGeneral( + const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. 
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + absl::Span window_strides, Padding padding, + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); + absl::Span window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. XlaOp ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. 
XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. @@ -572,25 +576,14 @@ class XlaBuilder { // Enqueues a call instruction onto the computation. XlaOp Call(const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Enqueues a custom call instruction onto the computation. // During code generation, a call instruction is emitted which targets a // symbol with the name |call_target_name|. The |operands| are passed to the // call instruction. |shape| is the resultant shape. XlaOp CustomCall(const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); - - // Enqueues a pseudo-op to represent host-side computation data-dependencies. - // During code generation, host send and receive operations will be generated - // to transfer |operands| to the host and a single result of |shape| back to - // the device. Host send/recv operations are emitted using |channel_name|. - // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO - // instruction scheduling. 
- XlaOp HostCompute(tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, - const Shape& shape); + absl::Span operands, const Shape& shape); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one @@ -599,65 +592,70 @@ class XlaBuilder { // Enqueues a complex compose instruction onto the computation. XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a complex conjugate instruction onto the computation. XlaOp Conj(const XlaOp& operand); // Enqueues an add instruction onto the computation. XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a subtract instruction onto the computation. XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a multiply instruction onto the computation. XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a divide instruction onto the computation. XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a remainder instruction onto the computation. XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a max instruction onto the computation. XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a min instruction onto the computation. 
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Element-wise logical operators XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Not(const XlaOp& operand); XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); + XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); + XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); + + // Reduces several arrays simultaneously among the provided dimensions, given + // "computation" as a reduction operator. + XlaOp Reduce(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); // Convenience wrapper around the above that reduces all the dimensions in the // operand shape. 
@@ -667,25 +665,23 @@ class XlaBuilder { // Enqueues a windowed reduce instruction onto the computation. XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); // As ReduceWindow(), but the padding is given in the format // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All // replicas supply one input to the sum and all replicas receive the resulting // sum for each subgroup. - XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_group_ids = {}); + XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups = {}); // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then @@ -694,10 +690,11 @@ class XlaBuilder { // scalars, e.g., add, min, or max. The way that AllReduce is applied is // configured by: // - // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all - // replicas belong to one group. Allreduce will be applied within subgroups. - // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, - // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // - `replica_groups`: each ReplicaGroup contains a list of replica id. If + // empty, all replicas belong to one group. Allreduce will be applied within + // subgroups. 
For example, we have 4 replicas, then + // replica_groups={{0,2},{1,3}} means, replica 0 and 2 are in subgroup 0, + // replica 1 and 3 are in subgroup 1. // // - `channel_id`: for Allreduce nodes from different modules, if they have // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will @@ -706,22 +703,25 @@ class XlaBuilder { // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_group_ids = {}, - const tensorflow::gtl::optional& channel_id = - tensorflow::gtl::nullopt); + absl::Span replica_groups = {}, + const absl::optional& channel_id = absl::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. - // - // TODO(b/110096724): This is NOT YET ready to use. XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups); + // Enqueues an operation that do an CollectivePermute of the operand cross + // cores. + XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding, const XlaOp& source, const XlaOp& init_value, const XlaComputation& scatter); @@ -730,18 +730,17 @@ class XlaBuilder { // returned by MakePadding(). 
XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); // Enqueues an abs instruction onto the computation. XlaOp Abs(const XlaOp& operand); // Enqueues a atan2 instruction onto the computation. XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); @@ -788,7 +787,7 @@ class XlaBuilder { // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an operator that tests if the operand's values are finite, i.e., // not Inf or NaN. Defined only for floating-point types. Returns an array of @@ -796,6 +795,12 @@ class XlaBuilder { // entry was NaN. XlaOp IsFinite(const XlaOp& operand); + // Enqueues an iota operation onto the computation. + XlaOp Iota(const Shape& shape, int64 iota_dimension); + + // Enqueues a rank-1 iota operation onto the computation. + XlaOp Iota(PrimitiveType type, int64 size); + // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, @@ -812,14 +817,12 @@ class XlaBuilder { XlaOp Neg(const XlaOp& operand); // Enqueues a transpose instruction onto the computation. 
- XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); + XlaOp Transpose(const XlaOp& operand, absl::Span permutation); // Enqueues a reverse instruction onto the computation. The order of the // elements in the given dimensions is reversed (i.e., the element at index i // is moved to index dimension_size - 1 - i). - XlaOp Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. // If only keys are provided: @@ -837,18 +840,16 @@ class XlaBuilder { // * The result is a tuple that consists of a sorted tensor of keys (along the // provided dimension, as above) as the first element, and a tensor with their // corresponding values as the second element. - XlaOp Sort(XlaOp keys, - tensorflow::gtl::optional values = tensorflow::gtl::nullopt, + XlaOp Sort(XlaOp keys, absl::optional values = absl::nullopt, int64 dimension = -1); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); // Enqueues a map instruction onto the computation. - XlaOp Map(tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands = {}); + XlaOp Map(absl::Span operands, const XlaComputation& computation, + absl::Span dimensions, + absl::Span static_operands = {}); // Enqueues a N(mu, sigma) random number generation instruction onto the // computation. @@ -873,9 +874,9 @@ class XlaBuilder { const int mantissa_bits); // Enqueues a Gather node onto the computation. - XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); + absl::Span slice_sizes); // Enqueues a Scatter node onto the computation. 
XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, @@ -903,7 +904,7 @@ class XlaBuilder { // Enqueues an AfterAll operation with no operands producing a token-shaped // value. - XlaOp AfterAll(tensorflow::gtl::ArraySlice tokens); + XlaOp AfterAll(absl::Span tokens); // Enqueues a Recv node onto the computation. The data comes from a Send // instruction that shares the same channel handle and its shape must @@ -950,9 +951,8 @@ class XlaBuilder { const XlaOp& grad_output, float epsilon, int64 feature_index); - StatusOr AddInstruction( - HloInstructionProto&& instr, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands = {}); + StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands = {}); void AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr); @@ -966,19 +966,17 @@ class XlaBuilder { // broadcast_dimensions specifies which dimensions to use for broadcasting // when the operation is between tensors of different ranks. XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Internal helper method that does the building for an arbitrary ternary op. XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs); XlaOp RngOp(RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape); + absl::Span parameters, const Shape& shape); - StatusOr InDimBroadcast( - const Shape& shape, const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_dimensions); + StatusOr InDimBroadcast(const Shape& shape, const XlaOp& operand, + absl::Span broadcast_dimensions); // Internal helper method that creates a sequence of instructions that // performs an explicit broadcast of the operand to the target shape. @@ -994,7 +992,7 @@ class XlaBuilder { // Returns shapes for the operands. 
StatusOr> GetOperandShapes( - tensorflow::gtl::ArraySlice operands) const; + absl::Span operands) const; // A visitor which checks whether an operation is a compile-time constant, // meaning that it doesn't depend on any parameters, or on any stateful @@ -1011,12 +1009,11 @@ class XlaBuilder { // Helper function for creating a Window proto from user-supplied data. // Returns error if the user-supplied data was invalid. - StatusOr MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation) const; + StatusOr MakeWindow(absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation) const; string name_; // Name to use for the built computation. @@ -1045,7 +1042,7 @@ class XlaBuilder { // Sharding for this operator. This is structured as a "model"-like operation, // in order to simplify client code, similar to metadata_. - tensorflow::gtl::optional sharding_; + absl::optional sharding_; // Mode bit that indicates whether to die when a first error is encountered. 
bool die_immediately_on_error_ = false; @@ -1060,7 +1057,7 @@ class XlaBuilder { friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value); template friend XlaOp ConstantR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice values); + absl::Span values); friend XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values); template @@ -1100,178 +1097,183 @@ class XlaBuilder { friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); friend XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + absl::Span broadcast_sizes); friend XlaOp BroadcastInDim( const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions); + const absl::Span broadcast_dimensions); friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, const PaddingConfig& padding_config); - friend XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + friend XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes); - friend XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); + friend XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); friend XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); friend XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, const XlaOp& start_indices); friend XlaOp 
ConcatInDim(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, - int64 dimension); + absl::Span operands, int64 dimension); friend void Trace(const string& tag, const XlaOp& operand); friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); - friend XlaOp Tuple(XlaBuilder* builder, - tensorflow::gtl::ArraySlice elements); + friend XlaOp Tuple(XlaBuilder* builder, absl::Span elements); friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); - friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); + absl::Span broadcast_dimensions); + friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfigProto* precision_config_proto); friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers); + const DotDimensionNumbers& dimension_number, + const PrecisionConfigProto* precision_config_proto); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + absl::Span window_strides, Padding padding, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto); friend XlaOp 
ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); - friend XlaOp ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); + absl::Span window_strides, Padding padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto); + friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count, + const PrecisionConfigProto* precision_config_proto); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config); friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, const string& outfeed_config); friend 
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); - friend XlaOp HostCompute(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, - const Shape& shape); + absl::Span operands, const Shape& shape); friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Conj(const XlaOp& operand); friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, 
- tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Not(const XlaOp& operand); - friend XlaOp ShiftLeft( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions); friend XlaOp ShiftRightArithmetic( const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); - friend XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); + friend XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions); friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); + friend XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation); - friend XlaOp ReduceWindow( - const XlaOp& operand, const XlaOp& init_value, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); + friend XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding); friend XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - friend XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_group_ids); - friend XlaOp 
CrossReplicaSum( - const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_group_ids, - const tensorflow::gtl::optional& channel_id); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding); + friend XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups); + friend XlaOp CrossReplicaSum(const XlaOp& operand, + const XlaComputation& computation, + absl::Span replica_groups, + const absl::optional& channel_id); friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups); - friend XlaOp SelectAndScatter( - const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + friend XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs); + friend XlaOp SelectAndScatter(const XlaOp& operand, + const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, const XlaOp& source, + const XlaOp& init_value, + const XlaComputation& scatter); friend XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); friend XlaOp Abs(const XlaOp& operand); friend XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp Exp(const XlaOp& operand); friend XlaOp 
Expm1(const XlaOp& operand); friend XlaOp Floor(const XlaOp& operand); @@ -1287,28 +1289,25 @@ class XlaBuilder { friend XlaOp Real(const XlaOp& operand); friend XlaOp Imag(const XlaOp& operand); friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); friend XlaOp IsFinite(const XlaOp& operand); - // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota - // in xla/client/lib/numeric.h with this (renamed to xla::Iota). - friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); + friend XlaOp Iota(XlaBuilder* builder, const Shape& shape, + int64 iota_dimension); + friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); friend XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); friend XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); friend XlaOp Neg(const XlaOp& operand); friend XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); - friend XlaOp Rev(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); - friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional values, - int64 dimension); + absl::Span permutation); + friend XlaOp Rev(const XlaOp& operand, absl::Span dimensions); + friend XlaOp Sort(XlaOp keys, absl::optional values, int64 dimension); friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); - friend XlaOp Map(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, + friend XlaOp Map(XlaBuilder* builder, absl::Span operands, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + absl::Span dimensions, + absl::Span static_operands); friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape); friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape); @@ -1320,9 +1319,9 @@ class 
XlaBuilder { const XlaComputation& false_computation); friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); - friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, + friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); + absl::Span slice_sizes); friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, @@ -1356,8 +1355,7 @@ class XlaBuilder { const Shape& shape_with_layout, const string& outfeed_config); friend XlaOp CreateToken(XlaBuilder* builder); - friend XlaOp AfterAll(XlaBuilder* builder, - tensorflow::gtl::ArraySlice tokens); + friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); }; // RAII-style object: sets the current sharding assignment in builder on @@ -1365,7 +1363,7 @@ class XlaBuilder { class XlaScopedShardingAssignment { public: XlaScopedShardingAssignment(xla::XlaBuilder* builder, - tensorflow::gtl::optional sharding) + absl::optional sharding) : builder_(builder), prev_sharding_(builder->sharding()) { SetSharding(sharding); } @@ -1377,7 +1375,7 @@ class XlaScopedShardingAssignment { ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } private: - void SetSharding(const tensorflow::gtl::optional& sharding) { + void SetSharding(const absl::optional& sharding) { if (sharding.has_value()) { builder_->SetSharding(sharding.value()); } else { @@ -1386,7 +1384,7 @@ class XlaScopedShardingAssignment { } xla::XlaBuilder* const builder_; - tensorflow::gtl::optional prev_sharding_; + absl::optional prev_sharding_; }; // Free functions for building XlaOps. 
The intention is that these will @@ -1421,8 +1419,7 @@ XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); template XlaOp ConstantR0(XlaBuilder* builder, NativeT value); template -XlaOp ConstantR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice values); +XlaOp ConstantR1(XlaBuilder* builder, absl::Span values); XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values); template XlaOp ConstantR2(XlaBuilder* builder, @@ -1471,8 +1468,7 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value); // The new dimensions index into copies of the operand, i.e. // // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] -XlaOp Broadcast(const XlaOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); +XlaOp Broadcast(const XlaOp& operand, absl::Span broadcast_sizes); // Performs in-dimension-style broadcast. // @@ -1491,9 +1487,8 @@ XlaOp Broadcast(const XlaOp& operand, // will generate output // [1 , 1] // [2 , 2] -XlaOp BroadcastInDim( - const XlaOp& operand, const Shape& shape, - const tensorflow::gtl::ArraySlice broadcast_dimensions); +XlaOp BroadcastInDim(const XlaOp& operand, const Shape& shape, + const absl::Span broadcast_dimensions); // Enqueues a pad operation onto the computation that pads the given value on // the edges as well as between the elements of the input. padding_config @@ -1506,15 +1501,13 @@ XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value, // given, followed by reshaping it into the shape with the given dimension // sizes (also major to minor). Conceptually, this is a limited form of // "shape casting". -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); +XlaOp Reshape(const XlaOp& operand, absl::Span dimensions, + absl::Span new_sizes); // Enqueues an operation onto the computation that collapses the operand, from // first to last dimension (C order), then reshapes it to the given dimension // sizes. 
Conceptually, this is a limited form of "shape casting". -XlaOp Reshape(const XlaOp& operand, - tensorflow::gtl::ArraySlice new_sizes); +XlaOp Reshape(const XlaOp& operand, absl::Span new_sizes); // Wrapper for Reshape. // Enqueues an operation to collapse the provided dimensions; e.g. an @@ -1534,8 +1527,7 @@ XlaOp Reshape(const XlaOp& operand, // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. -XlaOp Collapse(const XlaOp& operand, - tensorflow::gtl::ArraySlice dimensions); +XlaOp Collapse(const XlaOp& operand, absl::Span dimensions); // Enqueues a slice operation onto the computation that slices the operand // from the start indices to the limit indices; e.g. @@ -1548,10 +1540,9 @@ XlaOp Collapse(const XlaOp& operand, // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D // range notation. // The strides parameter determines the stride over the slice -XlaOp Slice(const XlaOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); +XlaOp Slice(const XlaOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); // Enqueues a slice operation in a given dimension, taking all other // dimensions as they are; e.g. if dimno is 1 from start_index 2 to @@ -1572,7 +1563,7 @@ XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index, // Slice index calculations are computed modulo input dimension sizes to // prevent dynamic start indices from generating out-of-bound array accesses. XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Enqueues a dynamic update slice operation onto the computation, which // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. 
@@ -1595,8 +1586,8 @@ XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update, // Enqueues a concatenate instruction onto the computation. 'operands' must // have >= 1 entry. -XlaOp ConcatInDim(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, int64 dimension); +XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, + int64 dimension); // Enqueue a tracing operation onto the computation; the computation will emit // a logging message with the operand. @@ -1607,82 +1598,90 @@ void Trace(const string& tag, const XlaOp& operand); XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false); // Enqueues a tuple-creation instruction onto the computation. -XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice elements); +XlaOp Tuple(XlaBuilder* builder, absl::Span elements); // Enqueues a tuple-element-get instruction onto the computation. XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index); // Enqueues an equal-to comparison instruction onto the computation. XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a greater-than comparison instruction onto the computation. XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-than comparison instruction onto the computation. 
XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a dot instruction onto the computation. -XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs); +XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a general dot instruction onto the computation. XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding); + absl::Span window_strides, Padding padding, + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. 
XlaOp ConvWithGeneralDimensions( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfigProto* precision_config_proto = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. 
XlaOp Fft(const XlaOp& operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. @@ -1714,26 +1713,14 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, // Enqueues a call instruction onto the computation. XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Enqueues a custom call instruction onto the computation. // During code generation, a call instruction is emitted which targets a // symbol with the name |call_target_name|. The |operands| are passed to the // call instruction. |shape| is the resultant shape. XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); - -// Enqueues a pseudo-op to represent host-side computation data-dependencies. -// During code generation, host send and receive operations will be generated -// to transfer |operands| to the host and a single result of |shape| back to -// the device. Host send/recv operations are emitted using |channel_name|. -// Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO -// instruction scheduling. -XlaOp HostCompute(XlaBuilder* builder, - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, - const Shape& shape); + absl::Span operands, const Shape& shape); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one @@ -1742,65 +1729,70 @@ XlaOp HostCompute(XlaBuilder* builder, // Enqueues a complex compose instruction onto the computation. 
XlaOp Complex(const XlaOp& real, const XlaOp& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a complex conjugate instruction onto the computation. XlaOp Conj(const XlaOp& operand); // Enqueues an add instruction onto the computation. XlaOp Add(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a subtract instruction onto the computation. XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a multiply instruction onto the computation. XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a divide instruction onto the computation. XlaOp Div(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a remainder instruction onto the computation. XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a max instruction onto the computation. XlaOp Max(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues a min instruction onto the computation. 
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Element-wise logical operators XlaOp And(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Or(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); XlaOp Not(const XlaOp& operand); XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); -XlaOp ShiftRightArithmetic( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); -XlaOp ShiftRightLogical( - const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); +XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); +XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs, + absl::Span broadcast_dimensions = {}); // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); + +// Reduces several arrays simultaneously among the provided dimensions, given +// "computation" as a reduction operator. +XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); // Convenience wrapper around the above that reduces all the dimensions in the // operand shape. 
@@ -1810,25 +1802,23 @@ XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value, // Enqueues a windowed reduce instruction onto the computation. XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); // As ReduceWindow(), but the padding is given in the format // returned by MakePadding(). XlaOp ReduceWindowWithGeneralPadding( const XlaOp& operand, const XlaOp& init_value, const XlaComputation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding); // Returns the sum of the operand value within each subgroup of replicas. All // replicas supply one input to the sum and all replicas receive the resulting // sum for each subgroup. -XlaOp CrossReplicaSum( - const XlaOp& operand, - tensorflow::gtl::ArraySlice replica_group_ids = {}); +XlaOp CrossReplicaSum(const XlaOp& operand, + absl::Span replica_groups = {}); // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then @@ -1837,52 +1827,61 @@ XlaOp CrossReplicaSum( // scalars, e.g., add, min, or max. The way that AllReduce is applied is // configured by: // -// - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all -// replicas belong to one group. Allreduce will be applied within subgroups. -// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, -// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. +// - `replica_groups`: each ReplicaGroup contains a list of replica id. If +// empty, all replicas belong to one group. 
Allreduce will be applied within +// subgroups. For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} +// means, replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // // - `channel_id`: for Allreduce nodes from different modules, if they have the // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be // applied cross modules. // // TODO(b/79737069): Rename this to AllReduce when it's ready to use. -XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice replica_group_ids = {}, - const tensorflow::gtl::optional& - channel_id = tensorflow::gtl::nullopt); +XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const absl::optional& channel_id = absl::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. -// -// TODO(b/110096724): This is NOT YET ready to use. XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups = {}); +// Enqueues an collective operation that sends and receives data cross replicas. +// +// - `source_target_pair`: a list of (source_replica_id, target_replica_id) +// pairs. For each pair, the operand is sent from source replica to target +// replica. Note that, 1) any two pairs should not have the same target replica +// id, and they should not have the same source replica id; 2) if a replica id +// is not a target in any pair, then the output on that replica is a tensor +// consists of 0(s) with the same shape as the input. +XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. 
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - Padding padding, const XlaOp& source, - const XlaOp& init_value, const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, Padding padding, + const XlaOp& source, const XlaOp& init_value, + const XlaComputation& scatter); // As SelectAndScatter(), but the padding is given in the format // returned by MakePadding(). XlaOp SelectAndScatterWithGeneralPadding( const XlaOp& operand, const XlaComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const XlaOp& source, const XlaOp& init_value, - const XlaComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const XlaOp& source, + const XlaOp& init_value, const XlaComputation& scatter); // Enqueues an abs instruction onto the computation. XlaOp Abs(const XlaOp& operand); // Enqueues a atan2 instruction onto the computation. XlaOp Atan2(const XlaOp& y, const XlaOp& x, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); @@ -1929,7 +1928,7 @@ XlaOp Imag(const XlaOp& operand); // Enqueues a lhs^rhs computation onto the computation. XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + absl::Span broadcast_dimensions = {}); // Enqueues an operator that tests if the operand's values are finite, i.e., // not Inf or NaN. Defined only for floating-point types. Returns an array of @@ -1937,6 +1936,12 @@ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, // entry was NaN. XlaOp IsFinite(const XlaOp& operand); +// Enqueues an iota operation onto the computation. 
+XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension); + +// Enqueues a rank-1 iota operation onto the computation. +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size); + // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); @@ -1951,13 +1956,12 @@ XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type); XlaOp Neg(const XlaOp& operand); // Enqueues a transpose instruction onto the computation. -XlaOp Transpose(const XlaOp& operand, - tensorflow::gtl::ArraySlice permutation); +XlaOp Transpose(const XlaOp& operand, absl::Span permutation); // Enqueues a reverse instruction onto the computation. The order of the // elements in the given dimensions is reversed (i.e., the element at index i // is moved to index dimension_size - 1 - i). -XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions); +XlaOp Rev(const XlaOp& operand, absl::Span dimensions); // Enqueues a sort (as increasing order) instruction onto the computation. // If only keys are provided: @@ -1975,18 +1979,16 @@ XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions); // * The result is a tuple that consists of a sorted tensor of keys (along the // provided dimension, as above) as the first element, and a tensor with their // corresponding values as the second element. -XlaOp Sort(XlaOp keys, - tensorflow::gtl::optional values = tensorflow::gtl::nullopt, +XlaOp Sort(XlaOp keys, absl::optional values = absl::nullopt, int64 dimension = -1); // Enqueues a clamp instruction onto the computation. XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max); // Enqueues a map instruction onto the computation. 
-XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice operands, - const XlaComputation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands = {}); +XlaOp Map(XlaBuilder* builder, absl::Span operands, + const XlaComputation& computation, absl::Span dimensions, + absl::Span static_operands = {}); // Enqueues a N(mu, sigma) random number generation instruction onto the // computation. @@ -2011,9 +2013,9 @@ XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits, const int mantissa_bits); // Enqueues a Gather node onto the computation. -XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, +XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); + absl::Span slice_sizes); // Enqueues a Scatter node onto the computation. XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, @@ -2071,7 +2073,7 @@ XlaOp CreateToken(XlaBuilder* builder); // Enqueues an AfterAll instruction which produces a token-shaped value and // takes a variadic number of token-shaped operands. The number of operands must // be greater than zero. Used for joining tokens. -XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice tokens); +XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); // Normalizes operand across spatial and batch dimensions for each feature. 
// @@ -2119,7 +2121,7 @@ XlaOp XlaBuilder::ConstantR0(NativeT value) { } template -XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice values) { +XlaOp XlaBuilder::ConstantR1(absl::Span values) { return ConstantLiteral(*LiteralUtil::CreateR1(values)); } @@ -2196,8 +2198,7 @@ XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { } template -XlaOp ConstantR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice values) { +XlaOp ConstantR1(XlaBuilder* builder, absl::Span values) { return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); } diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 49a15ec3b449bdec07aa6ecfbc40b7b9f62c3f4e..7c37ed00cd3dcc214fb0b36c0161d3c39a5bf8c8 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -320,6 +320,15 @@ TEST_F(XlaBuilderTest, AllToAll) { ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8}))); } +TEST_F(XlaBuilderTest, CollectivePermute) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute); +} + TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc index 3543d41fc2656ec028646edebc0bf5b6af7f67a5..22c9e83bb2ae9e3e205bdd480b64c703e31c6ffd 100644 --- a/tensorflow/compiler/xla/client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -32,7 +32,7 @@ StatusOr> XlaComputation::Snapshot() const { if (IsNull()) { return InvalidArgument("Computation is invalid."); } - auto session = MakeUnique(); + auto session = absl::make_unique(); *session->mutable_hlo()->mutable_hlo_module() = proto_; return std::move(session); } diff --git a/tensorflow/compiler/xla/device_util.h b/tensorflow/compiler/xla/device_util.h index 1a51fdee680721a4a03fa5de79a81746d92af76b..6d51126d882f87a84b054e9db599b995868824bf 100644 --- a/tensorflow/compiler/xla/device_util.h +++ b/tensorflow/compiler/xla/device_util.h @@ -21,8 +21,8 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -30,8 +30,8 @@ namespace xla { // Returns a string that represents the device in terms of platform and ordinal; // e.g. the first CUDA device will be "cuda:0" string DeviceIdentifier(se::StreamExecutor* stream_exec) { - return tensorflow::strings::StrCat(stream_exec->platform()->Name(), ":", - stream_exec->device_ordinal()); + return absl::StrCat(stream_exec->platform()->Name(), ":", + stream_exec->device_ordinal()); } } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index ffd1fb79e986f82e1c2721f0eefbf3b4c0838e41..3fadabcf5207097aa875d654320b930b1ed94ad3 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -18,16 +18,16 @@ limitations under the License. 
#include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" namespace xla { /* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex( - const Shape& shape, tensorflow::gtl::ArraySlice multi_index) { + const Shape& shape, absl::Span multi_index) { DCHECK_EQ(shape.dimensions_size(), multi_index.size()); // Padding and nested layouts not supported yet. DCHECK_EQ(0, shape.layout().padded_dimensions_size()); @@ -36,7 +36,7 @@ namespace xla { DCHECK_GE(multi_index[i], 0); DCHECK_LT(multi_index[i], shape.dimensions(i)) << "indexing beyond extent in dimension " << i << ":" - << "\n\tindex: " << tensorflow::str_util::Join(multi_index, ",") + << "\n\tindex: " << absl::StrJoin(multi_index, ",") << "\n\tshape: " << ShapeUtil::HumanString(shape); } @@ -118,8 +118,8 @@ namespace xla { return multi_index; } -/* static */ bool IndexUtil::BumpIndices( - const Shape& shape, tensorflow::gtl::MutableArraySlice indices) { +/* static */ bool IndexUtil::BumpIndices(const Shape& shape, + absl::Span indices) { for (int64 dimno = indices.size() - 1; dimno >= 0; --dimno) { int64 limit = shape.dimensions(dimno); if (indices[dimno] + 1 < limit) { @@ -149,8 +149,8 @@ namespace xla { return stride; } -/* static */ bool IndexUtil::IndexInBounds( - const Shape& shape, tensorflow::gtl::ArraySlice index) { +/* static */ bool IndexUtil::IndexInBounds(const Shape& shape, + absl::Span index) { int64 rank = ShapeUtil::Rank(shape); if (rank != index.size()) { return false; @@ -163,9 +163,8 @@ namespace xla { return true; } -/* static */ int IndexUtil::CompareIndices( - tensorflow::gtl::ArraySlice lhs, - tensorflow::gtl::ArraySlice rhs) { +/* static */ int IndexUtil::CompareIndices(absl::Span lhs, + absl::Span rhs) { int64 rank = lhs.size(); CHECK_EQ(rhs.size(), rank); 
for (int64 dim = 0; dim < rank; ++dim) { diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index 142006f2626e83d3254f2de65fc28fd5d6694e53..2979cf87dde92893ce2151cb09b46c8db8473b31 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -20,9 +20,9 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -35,7 +35,7 @@ class IndexUtil { // on the shape and its layout. The first index in the multi_index is // dimension 0. static int64 MultidimensionalIndexToLinearIndex( - const Shape& shape, tensorflow::gtl::ArraySlice multi_index); + const Shape& shape, absl::Span multi_index); // Converts a linear index into multidimensional index (eg {x, y, z}) based on // the shape and its layout. The first index in the returned multidimensional @@ -58,8 +58,7 @@ class IndexUtil { // // Returns true iff the indices were successfully bumped; false if we've hit // the limit where it can no longer be bumped in-bounds. - static bool BumpIndices(const Shape& shape, - tensorflow::gtl::MutableArraySlice indices); + static bool BumpIndices(const Shape& shape, absl::Span indices); // Calculates the stride size (in number of elements, not byte size) of a // given logical shape dimension (from 0 to rank-1). If available, padded @@ -71,15 +70,14 @@ class IndexUtil { // Returns true iff the given multi-index is contained in the bounds for the // shape. - static bool IndexInBounds(const Shape& shape, - tensorflow::gtl::ArraySlice index); + static bool IndexInBounds(const Shape& shape, absl::Span index); // Compares the given indices in lexicographic order. lhs[0] and rhs[0] are // compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is larger, // then -1 is returned. 
If rhs is larger, then 1 is returned. Otherwise, 0 is // returned. - static int CompareIndices(tensorflow::gtl::ArraySlice lhs, - tensorflow::gtl::ArraySlice rhs); + static int CompareIndices(absl::Span lhs, + absl::Span rhs); private: TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc index 7c4efdee484d9530a69b31cbe3a0d69a8a3cffa7..93522d2ca87a7eba8d3c7533785c54e63ce507b0 100644 --- a/tensorflow/compiler/xla/index_util_test.cc +++ b/tensorflow/compiler/xla/index_util_test.cc @@ -142,13 +142,13 @@ TEST(IndexUtilTest, LinearToMultiToLinear) { TEST(IndexUtilTest, BumpIndices2x2) { auto shape = ShapeUtil::MakeShape(S32, {2, 2}); std::vector indices = {0, 0}; - EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices))); EXPECT_THAT(indices, ::testing::ElementsAre(0, 1)); - EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices))); EXPECT_THAT(indices, ::testing::ElementsAre(1, 0)); - EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_TRUE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices))); EXPECT_THAT(indices, ::testing::ElementsAre(1, 1)); - EXPECT_FALSE(IndexUtil::BumpIndices(shape, &indices)); + EXPECT_FALSE(IndexUtil::BumpIndices(shape, absl::MakeSpan(indices))); } } // namespace diff --git a/tensorflow/compiler/xla/iterator_util.h b/tensorflow/compiler/xla/iterator_util.h index a8bb8c7a7e6784e555f4e9dad73ecc78c668ac42..3a3ee21e7635b9dee61f59e4e8c69eec3d420c86 100644 --- a/tensorflow/compiler/xla/iterator_util.h +++ b/tensorflow/compiler/xla/iterator_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_ #include #include @@ -95,4 +95,4 @@ UnwrappingIterator MakeUnwrappingIterator(NestedIter iter) { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_ITERATOR_UTIL_H_ diff --git a/tensorflow/compiler/xla/iterator_util_test.cc b/tensorflow/compiler/xla/iterator_util_test.cc index 7bc3189507ec5233c6983eb26cfb07dc9bfadd52..ec8b66df2db0b9d8c045fbf6133f607e57c81c26 100644 --- a/tensorflow/compiler/xla/iterator_util_test.cc +++ b/tensorflow/compiler/xla/iterator_util_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/test.h" namespace xla { @@ -27,7 +27,7 @@ namespace { TEST(UnwrappingIteratorTest, Simple) { std::vector> v; for (int i = 0; i < 3; ++i) { - v.push_back(MakeUnique(i)); + v.push_back(absl::make_unique(i)); } int i = 0; for (auto iter = MakeUnwrappingIterator(v.begin()); @@ -51,7 +51,7 @@ TEST(UnwrappingIteratorTest, PostincrementOperator) { TEST(UnwrappingIteratorTest, StdFind) { std::list> l; for (int i = 0; i < 3; ++i) { - l.push_back(MakeUnique(i)); + l.push_back(absl::make_unique(i)); } EXPECT_EQ(l.begin()->get(), *std::find(MakeUnwrappingIterator(l.begin()), diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index b72d190d54591384392e79e73e90cf52df04a902..d310335618ded7b581e6ed632223218585bb791f 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -23,6 +23,8 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,8 +33,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -56,7 +56,7 @@ void SetDefaultLayoutToContainer( } // namespace /* static */ Layout LayoutUtil::MakeLayout( - tensorflow::gtl::ArraySlice minor_to_major) { + absl::Span minor_to_major) { Layout layout; layout.set_format(DENSE); for (int64 dimension_number : minor_to_major) { @@ -66,7 +66,7 @@ void SetDefaultLayoutToContainer( } /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( - tensorflow::gtl::ArraySlice major_to_minor) { + absl::Span major_to_minor) { Layout layout; layout.set_format(DENSE); for (int i = major_to_minor.size() - 1; i >= 0; i--) { @@ -169,7 +169,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } else if (ShapeUtil::IsArray(shape)) { if (!shape.has_layout()) { return InvalidArgument("shape %s does not have a layout", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } return ValidateLayoutForShape(shape.layout(), shape); } else { @@ -177,7 +177,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (shape.has_layout()) { return InvalidArgument( "shape of primitive type %s should not have a layout", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return Status::OK(); } @@ -194,7 +194,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { layout.padded_dimensions_size() != 0) { return InvalidArgument( "shape of primitive type %s should not have a 
non-trivial layout", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return Status::OK(); } @@ -202,17 +202,17 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (layout.format() == INVALID_FORMAT) { return InvalidArgument( "Layout does not have a valid format: layout {%s}, shape {%s}", - layout.ShortDebugString().c_str(), shape.ShortDebugString().c_str()); + layout.ShortDebugString(), shape.ShortDebugString()); } if (layout.format() == DENSE) { if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( "layout minor_to_major field contains %d elements, " - "but shape is rank %lld: {%s}; shape: %s", + "but shape is rank %d: {%s}; shape: %s", layout.minor_to_major_size(), ShapeUtil::Rank(shape), - tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(), - shape.ShortDebugString().c_str()); + absl::StrJoin(layout.minor_to_major(), ", "), + shape.ShortDebugString()); } std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); @@ -221,12 +221,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { return InvalidArgument( "layout minor_to_major field has out-of-bounds value: %s", - HumanString(layout).c_str()); + HumanString(layout)); } if (dimensions_in_layout[dim]) { return InvalidArgument( "layout minor_to_major field has duplicate values: {%s}", - HumanString(layout).c_str()); + HumanString(layout)); } dimensions_in_layout[dim] = true; } @@ -234,14 +234,14 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (layout.padded_dimensions_size() > 0) { if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( - "layout has %d padded dimensions, but shape is rank %lld", + "layout has %d padded dimensions, but shape is rank %d", layout.padded_dimensions_size(), ShapeUtil::Rank(shape)); } for (int i = 0; i < layout.padded_dimensions_size(); ++i) { if (layout.padded_dimensions(i) < 
shape.dimensions(i)) { return InvalidArgument( - "for dimension %d, dimension padding (%lld) is smaller than " - "the dimension size (%lld) of the shape", + "for dimension %d, dimension padding (%d) is smaller than " + "the dimension size (%d) of the shape", i, layout.padded_dimensions(i), shape.dimensions(i)); } } @@ -307,7 +307,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return false; } -/* static */ tensorflow::gtl::ArraySlice LayoutUtil::PaddedDimensions( +/* static */ absl::Span LayoutUtil::PaddedDimensions( const Shape& shape) { CHECK(IsDenseArray(shape)); return AsInt64Slice(shape.layout().padded_dimensions()); @@ -363,13 +363,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return protobuf_util::ProtobufEquals(lhs, rhs); } -/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( +/* static */ absl::Span LayoutUtil::MinorToMajor( const Shape& shape) { CHECK(IsDenseArray(shape)); return AsInt64Slice(shape.layout().minor_to_major()); } -/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( +/* static */ absl::Span LayoutUtil::MinorToMajor( const Layout& layout) { CHECK(layout.format() == DENSE); return AsInt64Slice(layout.minor_to_major()); @@ -403,12 +403,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ string LayoutUtil::HumanString(const Layout& layout) { if (IsSparse(layout)) { - return tensorflow::strings::StrCat("sparse{", layout.max_sparse_elements(), - "}"); + return absl::StrCat("sparse{", layout.max_sparse_elements(), "}"); } CHECK(IsDense(layout)); - return tensorflow::strings::StrCat( - "{", tensorflow::str_util::Join(layout.minor_to_major(), ","), "}"); + return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}"); } namespace { @@ -474,7 +472,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { } /* static */ bool LayoutUtil::AreDimensionsConsecutive( - const Layout& layout, tensorflow::gtl::ArraySlice dims) { + const Layout& layout, absl::Span dims) 
{ CHECK(IsDense(layout)); std::vector positions_in_layout; for (int64 dim : dims) { diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 739bbe73675c7fb855627006028eafdf703d6540..b78883c2d870043032306637730c4666665125a8 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -20,10 +20,10 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -34,11 +34,11 @@ class LayoutUtil { public: // Creates a layout with the given minor-to-major dimension order. (This is a // convenience function for protobuf construction.) - static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + static Layout MakeLayout(absl::Span minor_to_major); // Similar to MakeLayout, but take indices in reverse order. static Layout MakeLayoutFromMajorToMinor( - tensorflow::gtl::ArraySlice major_to_minor); + absl::Span major_to_minor); // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) @@ -104,8 +104,7 @@ class LayoutUtil { // Returns the padded_dimensions array for the given Shape. Requires that the // shape is an array and has a dense layout. - static tensorflow::gtl::ArraySlice PaddedDimensions( - const Shape& shape); + static absl::Span PaddedDimensions(const Shape& shape); // Returns the given index of the padded_dimensions array for the given Shape. // Requires that the shape is an array and has a dense layout. @@ -138,8 +137,8 @@ class LayoutUtil { // Returns the minor_to_major array for the given Shape. Requires that the // shape is an array and has a dense layout. 
- static tensorflow::gtl::ArraySlice MinorToMajor(const Shape& shape); - static tensorflow::gtl::ArraySlice MinorToMajor(const Layout& layout); + static absl::Span MinorToMajor(const Shape& shape); + static absl::Span MinorToMajor(const Layout& layout); // Major(0) is the most major logical dimension number, Major(1) is the // second-most-major logical dimension number and so on. @@ -196,7 +195,7 @@ class LayoutUtil { // Returns whether the given dimensions are consecutive in the given layout, // not necessarily in the order given. static bool AreDimensionsConsecutive(const Layout& layout, - tensorflow::gtl::ArraySlice dims); + absl::Span dims); // Compute a hash for `layout`. static size_t Hash(const Layout& layout); diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index e4c825450dcd45a8fbeaacbb2ad145f94307176f..f25dae6ff411133c74502039f441060f1329ffd4 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -27,15 +27,15 @@ namespace { class LayoutUtilTest : public ::testing::Test { protected: Shape MakeShapeWithLayout(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { + absl::Span dimensions, + absl::Span minor_to_major) { Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); return shape; } Shape MakeShapeWithSparseLayout(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions, + absl::Span dimensions, int64 max_sparse_elements) { Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index 89353448e29ec3d97275dac288e23aa8e96e31b2..3e79129aafd234e5eab05d205f2017b54057795e 100644 --- 
a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -26,6 +26,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -39,6 +40,7 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings:str_format", ], ) @@ -56,6 +58,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -73,5 +76,7 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 1bf8948ef6ded56573d588258c3d9bbfaa55a50d..0d3136b0cc6a3a695eacb98c16200e46a144c571 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -17,9 +17,9 @@ limitations under the License. #include // NOLINT(build/c++11): only using std::call_once, not mutex. #include +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" #include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace legacy_flags { @@ -87,7 +87,7 @@ void AllocateFlags() { // Custom "sub-parser" lambda for xla_disable_hlo_passes. 
auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) { std::vector disabled_passes = - tensorflow::str_util::Split(comma_separated_values, ','); + absl::StrSplit(comma_separated_values, ','); for (const auto& passname : disabled_passes) { flag_values->add_xla_disable_hlo_passes(passname); } @@ -316,6 +316,13 @@ void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn), flag_values->xla_cpu_use_mkl_dnn(), "Generate calls to MKL-DNN in the CPU backend."), + tensorflow::Flag( + "xla_gpu_crash_on_verification_failures", + bool_setter_for( + &DebugOptions::set_xla_gpu_crash_on_verification_failures), + flag_values->xla_gpu_crash_on_verification_failures(), + "Crashes the program on extra verification failures, e.g. cuDNN " + "cross checking failures"), }); ParseFlagsFromEnv(*flag_objects); } diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h index e9cf435d83d8345e974d83f8e5340dafeba8e3b2..ee7eb019c07cf898e48886955b18710146644cac 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h @@ -17,10 +17,10 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ #include +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace legacy_flags { @@ -30,7 +30,7 @@ template void parse_xla_backend_extra_options(T* extra_options_map, string comma_separated_values) { std::vector extra_options_parts = - tensorflow::str_util::Split(comma_separated_values, ','); + absl::StrSplit(comma_separated_values, ','); // The flag contains a comma-separated list of options; some options // have arguments following "=", some don't. @@ -59,8 +59,7 @@ void parse_xla_backend_extra_options(T* extra_options_map, inline bool parse_xla_reduce_precision_option( HloReducePrecisionOptions* options, string option_string) { // Split off "LOCATION" from remainder of string. - std::vector eq_split = - tensorflow::str_util::Split(option_string, '='); + std::vector eq_split = absl::StrSplit(option_string, '='); if (eq_split.size() != 2) { return false; } @@ -80,26 +79,25 @@ inline bool parse_xla_reduce_precision_option( } // Split off "E,M" from remainder of string. - std::vector colon_split = - tensorflow::str_util::Split(eq_split[1], ':'); + std::vector colon_split = absl::StrSplit(eq_split[1], ':'); if (colon_split.size() != 2) { return false; } // Split E and M, and parse. std::vector bitsizes; - if (!tensorflow::str_util::SplitAndParseAsInts(colon_split[0], ',', - &bitsizes) || - bitsizes.size() != 2) { - return false; + for (const auto& s : absl::StrSplit(colon_split[0], ',')) { + bitsizes.emplace_back(); + if (!absl::SimpleAtoi(s, &bitsizes.back())) { + return false; + } } options->set_exponent_bits(bitsizes[0]); options->set_mantissa_bits(bitsizes[1]); // Split off OPS comma-separated list from remainder of string, if the // remainder exists. 
- std::vector semicolon_split = - tensorflow::str_util::Split(colon_split[1], ';'); + std::vector semicolon_split = absl::StrSplit(colon_split[1], ';'); if (semicolon_split.size() > 2) { return false; } @@ -113,8 +111,7 @@ inline bool parse_xla_reduce_precision_option( options->add_opcodes_to_suffix(i); } } else { - std::vector opcodes = - tensorflow::str_util::Split(opcode_string, ','); + std::vector opcodes = absl::StrSplit(opcode_string, ','); for (const string& opcode : opcodes) { bool found = false; for (int i = 0; i < HloOpcodeCount(); i++) { @@ -132,8 +129,7 @@ inline bool parse_xla_reduce_precision_option( // Process the NAMES string, if it exists. if (semicolon_split.size() == 2) { - std::vector opnames = - tensorflow::str_util::Split(semicolon_split[1], ','); + std::vector opnames = absl::StrSplit(semicolon_split[1], ','); for (const string& opname : opnames) { if (opname.length() > 0) { options->add_opname_substrings_to_suffix(opname); diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc index 0ed788a9676fe9b1bd06fb3ceabf627c108a2c70..6f197aec53c7596e84437a03affa9118f22f5a1d 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc index 7b6ae311c1099dccb8dceb2f49743c1b185cd5ab..138c0c852e2bb0527d171f25b4d96cedc5671516 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/test.h" @@ -106,8 +106,8 @@ TEST(ParseFlagsFromEnv, File) { if (tmp_dir == nullptr) { tmp_dir = kTempDir; } - string tmp_file = tensorflow::strings::Printf("%s/parse_flags_from_env.%d", - tmp_dir, getpid()); + string tmp_file = + absl::StrFormat("%s/parse_flags_from_env.%d", tmp_dir, getpid()); FILE* fp = fopen(tmp_file.c_str(), "w"); CHECK_NE(fp, nullptr) << "can't write to " << tmp_file; for (int i = 0; kTestFlagString[i] != '\0'; i++) { diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 36e472568ecfdb97c828817ed339260ee7878723..3f7635bd400c6ec87e0e3a739658272e906a72fb 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -22,6 +22,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -30,19 +34,15 @@ limitations under the License. 
#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -using tensorflow::strings::Printf; -using tensorflow::strings::StrCat; - namespace xla { - namespace { +using absl::StrCat; +using absl::StrFormat; + constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; // Converts between little and big endian. @@ -73,7 +73,7 @@ std::ostream& operator<<(std::ostream& out, const Literal& literal) { MutableLiteralBase::StrideConfig::StrideConfig( const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions) + absl::Span dimensions) : dimensions(dimensions), base(dimensions.size(), 0), step(dimensions.size(), 1) { @@ -134,7 +134,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { Literal::Literal(const Shape& shape, bool allocate_arrays) : MutableLiteralBase() { - shape_ = MakeUnique(shape); + shape_ = absl::make_unique(shape); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); root_piece_->set_subshape(shape_.get()); @@ -175,7 +175,7 @@ Literal& Literal::operator=(Literal&& other) { } std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = MakeUnique(shape); + auto literal = absl::make_unique(shape); literal->root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { if (ShapeUtil::IsArray(piece->subshape())) { @@ -197,14 +197,13 @@ SparseIndexArray* MutableLiteralBase::sparse_indices( template Status MutableLiteralBase::CopySliceFromInternal( - const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { + const 
LiteralBase& src_literal, absl::Span src_base, + absl::Span dest_base, absl::Span copy_size) { TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); auto linear_index = [](const Shape& shape, - tensorflow::gtl::ArraySlice multi_index) { + absl::Span multi_index) { return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); }; @@ -232,7 +231,7 @@ Status MutableLiteralBase::CopySliceFromInternal( MutableLiteralBase::StrideConfig stride_config(src_literal.shape(), shape(), copy_size); - auto copy_proc = [&](tensorflow::gtl::ArraySlice indexes) { + auto copy_proc = [&](absl::Span indexes) { // Map from multi-dimensional index, to source index. std::transform(indexes.begin(), indexes.end(), src_base.begin(), src_indexes.begin(), std::plus()); @@ -257,10 +256,9 @@ Status MutableLiteralBase::CopySliceFromInternal( return Status::OK(); } -Status MutableLiteralBase::CopyElementFrom( - const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index) { +Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, + absl::Span src_index, + absl::Span dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex( src_literal.shape(), src_index); @@ -289,7 +287,7 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { return InvalidArgument("LiteralProto has no layout"); } - auto literal = MakeUnique(proto.shape()); + auto literal = absl::make_unique(proto.shape()); TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { @@ -303,7 +301,7 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { if (proto_element->tuple_literals_size() != ShapeUtil::TupleElementCount(piece->subshape())) { return InvalidArgument( - "Expected %lld tuple elements 
in LiteralProto, has %d", + "Expected %d tuple elements in LiteralProto, has %d", ShapeUtil::TupleElementCount(piece->subshape()), proto_element->tuple_literals_size()); } @@ -355,9 +353,9 @@ namespace { // Copies the elements in 'src' to 'dest'. The shape and layout of the data in // the array slices are indicated by dest_shape and src_shape respectively. template -void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, - tensorflow::gtl::ArraySlice src, - const Shape& dest_shape, const Shape& src_shape) { +void CopyElementsBetween(absl::Span dest, + absl::Span src, const Shape& dest_shape, + const Shape& src_shape) { CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); if (ShapeUtil::IsZeroElementArray(dest_shape)) { return; @@ -366,7 +364,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, do { dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; - } while (IndexUtil::BumpIndices(dest_shape, &index)); + } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index))); } } // namespace @@ -404,7 +402,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { default: return Unimplemented( "Copying a Literal object with element type %s is not implemented.", - PrimitiveType_Name(subshape().element_type()).c_str()); + PrimitiveType_Name(subshape().element_type())); } } return Status::OK(); @@ -420,8 +418,8 @@ Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { return InvalidArgument( "Destination subshape incompatible with source subshape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_subshape).c_str()); + ShapeUtil::HumanString(dest_subshape), + ShapeUtil::HumanString(src_subshape)); } return root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { @@ -458,8 +456,8 @@ Status 
Literal::MoveFrom(Literal&& src_literal, if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { return InvalidArgument( "Destination subshape not equal to source shape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_literal.shape()).c_str()); + ShapeUtil::HumanString(dest_subshape), + ShapeUtil::HumanString(src_literal.shape())); } src_literal.root_piece_->ForEachSubpiece( @@ -479,7 +477,7 @@ Status Literal::MoveFrom(Literal&& src_literal, dest_piece.set_sparse_indices(src_piece.sparse_indices()); }); - src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + src_literal.shape_ = absl::make_unique(ShapeUtil::MakeNil()); delete src_literal.root_piece_; src_literal.root_piece_ = new LiteralBase::Piece(); src_literal.root_piece_->set_subshape(src_literal.shape_.get()); @@ -487,11 +485,10 @@ Status Literal::MoveFrom(Literal&& src_literal, return Status::OK(); } -Status MutableLiteralBase::CopySliceFrom( - const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { +Status MutableLiteralBase::CopySliceFrom(const LiteralSlice& src_literal, + absl::Span src_base, + absl::Span dest_base, + absl::Span copy_size) { TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) << ShapeUtil::HumanString(src_literal.shape()); @@ -566,7 +563,7 @@ std::unique_ptr LiteralBase::Relayout( Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); *subshape->mutable_layout() = new_layout; - auto result = MakeUnique(new_shape); + auto result = absl::make_unique(new_shape); TF_CHECK_OK(result->CopyFrom(*this)); return result; } @@ -591,8 +588,7 @@ std::unique_ptr LiteralBase::Relayout( } StatusOr> LiteralBase::Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice 
dimensions) const { + const Shape& result_shape, absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Broadcast only supports arrays."); } @@ -602,7 +598,7 @@ StatusOr> LiteralBase::Broadcast( result_shape.dimensions(dimensions[i])); } - std::unique_ptr result = MakeUnique(result_shape); + std::unique_ptr result = absl::make_unique(result_shape); // scratch_source_index is temporary storage space for the computed index into // the input literal. We put it here to avoid allocating an std::vector in @@ -615,7 +611,7 @@ StatusOr> LiteralBase::Broadcast( ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); ShapeUtil::ForEachIndex( - result_shape, [&](tensorflow::gtl::ArraySlice output_index) { + result_shape, [&](absl::Span output_index) { for (int64 i = 0; i < dimensions.size(); ++i) { scratch_source_index[i] = output_index[dimensions[i]]; } @@ -632,7 +628,7 @@ StatusOr> LiteralBase::Broadcast( } StatusOr> LiteralBase::Reshape( - tensorflow::gtl::ArraySlice dimensions) const { + absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); } @@ -654,14 +650,14 @@ StatusOr> LiteralBase::Reshape( return InvalidArgument( "Shapes before and after Literal::Reshape have different numbers " "of elements: %s vs %s.", - ShapeUtil::HumanString(shape()).c_str(), - ShapeUtil::HumanString(output->shape()).c_str()); + ShapeUtil::HumanString(shape()), + ShapeUtil::HumanString(output->shape())); } return std::move(output); } std::unique_ptr LiteralBase::Transpose( - tensorflow::gtl::ArraySlice permutation) const { + absl::Span permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) << "Given permutation is not a permutation of dimension numbers"; @@ -691,7 +687,7 @@ std::unique_ptr LiteralBase::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { 
layout->add_minor_to_major(inverse_permutation[index]); } - auto new_literal = MakeUnique(permuted_shape); + auto new_literal = absl::make_unique(permuted_shape); DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); @@ -700,12 +696,11 @@ std::unique_ptr LiteralBase::Transpose( template std::unique_ptr LiteralBase::SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const { - auto result_literal = MakeUnique(result_shape); + const Shape& result_shape, absl::Span start_indices) const { + auto result_literal = absl::make_unique(result_shape); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + [&](absl::Span indices, NativeT /*value*/) { for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } @@ -716,8 +711,8 @@ std::unique_ptr LiteralBase::SliceInternal( } std::unique_ptr LiteralBase::Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const { + absl::Span start_indices, + absl::Span limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; DimensionVector result_dimensions; @@ -756,12 +751,12 @@ Literal LiteralBase::Clone() const { } std::unique_ptr LiteralBase::CloneToUnique() const { - auto result = MakeUnique(shape()); + auto result = absl::make_unique(shape()); TF_CHECK_OK(result->CopyFrom(*this)); return result; } -string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, +string LiteralBase::GetAsString(absl::Span multi_index, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsDenseArray(subshape)); @@ -858,7 +853,7 @@ string LiteralBase::GetSparseElementAsString( } StatusOr 
LiteralBase::GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const { + absl::Span multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: @@ -874,9 +869,8 @@ StatusOr LiteralBase::GetIntegralAsS64( case U64: return Get(multi_index); default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); + return FailedPrecondition("Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type())); } } @@ -901,8 +895,8 @@ size_t LiteralBase::Hash() const { return hash_value; } -Status MutableLiteralBase::SetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index, int64 value) { +Status MutableLiteralBase::SetIntegralAsS64(absl::Span multi_index, + int64 value) { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: @@ -924,14 +918,13 @@ Status MutableLiteralBase::SetIntegralAsS64( Set(multi_index, value); break; default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); + return FailedPrecondition("Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type())); } return Status::OK(); } -tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( +absl::Span LiteralBase::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); CHECK_GE(sparse_element_number, 0); @@ -1000,7 +993,7 @@ void LiteralBase::Piece::SortSparseElementsInternal() { auto values = data(); CHECK_LE(num_elements, values.size()); sparse_indices()->SortWithValues( - tensorflow::gtl::MutableArraySlice(values.data(), num_elements)); + absl::Span(values.data(), num_elements)); } namespace { @@ -1029,9 +1022,9 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, element_index.push_back(i); std::vector element_pieces; 
ToStringHelper(literal, element_index, print_layout, &element_pieces); - tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); + tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); } - pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); + pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); pieces->push_back("\n)"); return; } @@ -1055,8 +1048,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(": "); } else { pieces->push_back("["); - pieces->push_back( - tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); + pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); pieces->push_back("]: "); } pieces->push_back(literal.GetSparseElementAsString(i)); @@ -1067,8 +1059,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, CHECK(LayoutUtil::IsDenseArray(subshape)); - auto element_to_string = - [&](tensorflow::gtl::ArraySlice indices) -> string { + auto element_to_string = [&](absl::Span indices) -> string { PrimitiveType element_type = subshape.element_type(); if (element_type == PRED) { // We display predicates in a densely packed form. 
@@ -1117,9 +1108,9 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(shape_to_string(subshape)); pieces->push_back(" {\n"); for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { pieces->push_back(" {"); for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { @@ -1137,11 +1128,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(shape_to_string(subshape)); pieces->push_back(" {\n"); for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); + pieces->push_back(StrFormat(" { /*i2=%d*/\n", i2)); for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { pieces->push_back(" {"); for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { @@ -1163,7 +1154,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(shape_to_string(subshape)); pieces->push_back(" {"); literal.EachCellAsString( - [&](tensorflow::gtl::ArraySlice indices, const string& value) { + [&](absl::Span indices, const string& value) { pieces->push_back(" "); pieces->push_back(value); }); @@ -1182,11 +1173,11 @@ string LiteralBase::ToString(bool print_layout) const { std::vector pieces; CHECK(LayoutUtil::HasLayout(this->shape())); ToStringHelper(*this, {}, print_layout, 
&pieces); - return tensorflow::str_util::Join(pieces, ""); + return absl::StrJoin(pieces, ""); } void LiteralBase::EachCellAsString( - const std::function indices, + const std::function indices, const string& value)>& per_cell) const { if (ShapeUtil::IsZeroElementArray(shape())) { return; @@ -1195,7 +1186,7 @@ void LiteralBase::EachCellAsString( shape(), /*linear_index=*/0); do { per_cell(indices, GetAsString(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); + } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices))); } namespace { @@ -1203,7 +1194,7 @@ template std::unique_ptr ConvertBetweenNativeTypesWithConverter( const LiteralBase& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( + auto result_literal = absl::make_unique(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType())); auto src_data = src_literal.data(); @@ -1249,14 +1240,12 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { template std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = MakeUnique( + auto result_literal = absl::make_unique( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative::type; - tensorflow::gtl::ArraySlice src_data = - src_literal.data(); - tensorflow::gtl::MutableArraySlice dest_data = - result_literal->data(); + absl::Span src_data = src_literal.data(); + absl::Span dest_data = result_literal->data(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = complex64(static_cast(src_data[i]), 0); @@ -1313,10 +1302,9 @@ StatusOr> ConvertIfDestTypeMatches( default: break; } - return Unimplemented( - "Converting from type %s to type %s is not implemented.", - 
PrimitiveType_Name(src_literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + return Unimplemented("Converting from type %s to type %s is not implemented.", + PrimitiveType_Name(src_literal.shape().element_type()), + PrimitiveType_Name(primitive_dest_type)); } StatusOr> ConvertSwitch( @@ -1345,11 +1333,10 @@ StatusOr> ConvertSwitch( #undef CONVERT_IF_DEST_TYPE_MATCHES // Other types are not yet supported. default: - return Unimplemented( - "%s from type %s to type %s is not implemented.", - (bitcast ? "Bitcast converting" : "Converting"), - PrimitiveType_Name(literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + return Unimplemented("%s from type %s to type %s is not implemented.", + (bitcast ? "Bitcast converting" : "Converting"), + PrimitiveType_Name(literal.shape().element_type()), + PrimitiveType_Name(primitive_dest_type)); } } @@ -1367,8 +1354,8 @@ StatusOr> LiteralBase::BitcastConvert( return InvalidArgument( "Cannot bitcast convert from %s to %s, bit widths are different: %d != " "%d", - PrimitiveType_Name(shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str(), + PrimitiveType_Name(shape().element_type()), + PrimitiveType_Name(primitive_dest_type), primitive_util::BitWidth(shape().element_type()), primitive_util::BitWidth(primitive_dest_type)); } @@ -1396,13 +1383,13 @@ StatusOr> LiteralBase::ConvertToShape( element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); elements.push_back(std::move(*new_element)); } - auto converted = MakeUnique(); - *converted = MutableLiteralBase::MoveIntoTuple(&elements); + auto converted = absl::make_unique(); + *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements)); return std::move(converted); } /* static */ Literal MutableLiteralBase::MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements) { + absl::Span elements) { std::vector element_shapes; for (const Literal& element : elements) { 
element_shapes.push_back(element.shape()); @@ -1435,6 +1422,12 @@ bool LiteralBase::Piece::EqualElementsInternal( bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); + if (ShapeUtil::Equal(subshape(), other.subshape()) && + LayoutUtil::IsDenseArray(subshape())) { + CHECK_EQ(size_bytes(), other.size_bytes()); + return memcmp(buffer(), other.buffer(), size_bytes()) == 0; + } + std::vector multi_index; switch (subshape().element_type()) { case PRED: @@ -1487,7 +1480,7 @@ bool LiteralBase::operator==(const LiteralBase& other) const { namespace { template -static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, +static bool AllElementsEqualValue(absl::Span data, NativeT value) { for (int64 i = 0; i < data.size(); ++i) { if (data[i] != value) { @@ -1686,7 +1679,62 @@ bool LiteralBase::IsAllFirst() const { }); } -bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { +bool LiteralBase::IsR1Iota() const { + if (!ShapeUtil::IsArray(shape())) { + return false; + } + + if (ShapeUtil::Rank(shape()) != 1) { + return false; + } + + auto is_iota_at_idx = [&](const int64 idx) { + switch (shape().element_type()) { + case U8: + return Get({idx}) == idx; + case U16: + return Get({idx}) == idx; + case U32: + return Get({idx}) == idx; + case U64: + return Get({idx}) == idx; + case S8: + return Get({idx}) == idx; + case S16: + return Get({idx}) == idx; + case S32: + return Get({idx}) == idx; + case S64: + return Get({idx}) == idx; + case F32: + return Get({idx}) == idx; + case F64: + return Get({idx}) == idx; + case F16: + return Get({idx}) == static_cast(idx); + case BF16: + return Get({idx}) == static_cast(idx); + case C64: + return Get({idx}) == complex64(idx, 0.0f); + case PRED: + return Get({idx}) == idx; + // token, opaque, tuple, etc. are all not iota. 
+ default: + return false; + } + }; + + const int64 elements = ShapeUtil::ElementsIn(shape()); + for (int64 idx = 0; idx < elements; ++idx) { + if (!is_iota_at_idx(idx)) { + return false; + } + } + + return true; +} + +bool LiteralBase::IsZero(absl::Span indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { case U8: @@ -1722,7 +1770,7 @@ namespace { template void CopyToRepeatedField(RepeatedFieldT* dest, - const tensorflow::gtl::ArraySlice src) { + const absl::Span src) { *dest = RepeatedFieldT(src.begin(), src.end()); } @@ -1800,7 +1848,7 @@ void* LiteralBase::Piece::untyped_data() { namespace { template -Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, +Status CopyFromRepeatedField(absl::Span dest, const RepeatedFieldT& src) { if (dest.size() != src.size()) { return InvalidArgument( @@ -1956,7 +2004,7 @@ MutableLiteralBase::~MutableLiteralBase() {} MutableBorrowingLiteral::MutableBorrowingLiteral( const MutableBorrowingLiteral& literal) : MutableLiteralBase() { - shape_ = MakeUnique(literal.shape()); + shape_ = absl::make_unique(literal.shape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -1967,7 +2015,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral( MutableBorrowingLiteral& MutableBorrowingLiteral::operator=( const MutableBorrowingLiteral& literal) { - shape_ = MakeUnique(literal.shape()); + shape_ = absl::make_unique(literal.shape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -1981,7 +2029,7 @@ MutableBorrowingLiteral& MutableBorrowingLiteral::operator=( MutableBorrowingLiteral::MutableBorrowingLiteral( const MutableLiteralBase& literal) : MutableLiteralBase() { - shape_ = MakeUnique(literal.shape()); + shape_ = absl::make_unique(literal.shape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -1992,7 +2040,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral( MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* 
literal) : MutableLiteralBase() { - shape_ = MakeUnique(literal->shape()); + shape_ = absl::make_unique(literal->shape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -2004,7 +2052,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral(MutableLiteralBase* literal) MutableBorrowingLiteral::MutableBorrowingLiteral( MutableBorrowingLiteral literal, const ShapeIndex& view_root) : MutableLiteralBase() { - shape_ = MakeUnique(literal.piece(view_root).subshape()); + shape_ = absl::make_unique(literal.piece(view_root).subshape()); CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = new Piece(); @@ -2016,7 +2064,7 @@ MutableBorrowingLiteral::MutableBorrowingLiteral( MutableBorrowingLiteral::MutableBorrowingLiteral(const char* src_buf_ptr, const Shape& shape) : MutableLiteralBase() { - shape_ = MakeUnique(shape); + shape_ = absl::make_unique(shape); CHECK(LayoutUtil::HasLayout(*shape_)); CHECK(!ShapeUtil::IsTuple(*shape_)); @@ -2061,7 +2109,7 @@ void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { } BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { + : LiteralBase(), shape_(absl::make_unique(shape)) { CHECK(ShapeUtil::IsArray(*shape_)); CHECK(LayoutUtil::HasLayout(*shape_)); @@ -2070,9 +2118,9 @@ BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) root_piece_.set_subshape(shape_.get()); } -BorrowingLiteral::BorrowingLiteral( - tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(MakeUnique(shape)) { +BorrowingLiteral::BorrowingLiteral(absl::Span src_buf_ptrs, + const Shape& shape) + : LiteralBase(), shape_(absl::make_unique(shape)) { CHECK(ShapeUtil::IsTuple(*shape_)); CHECK(!ShapeUtil::IsNestedTuple(*shape_)); CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 
92c0f903cbe252a153103aa8514bb5531696bbfe..b928cb637494dec220a0912fdea96ed25cde13ef 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -25,13 +25,15 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -40,8 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -70,13 +70,12 @@ class LiteralBase { // Serialize to proto. LiteralProto ToProto() const; - // Returns an ArraySlice of the array for this literal for the given NativeT + // Returns a Span of the array for this literal for the given NativeT // (e.g., float). CHECKs if the subshape of the literal at the given // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type // to native type. template - tensorflow::gtl::ArraySlice data( - const ShapeIndex& shape_index = {}) const; + absl::Span data(const ShapeIndex& shape_index = {}) const; // Returns a const pointer to the sparse index array. Returns nullptr if the // literal is not a sparse array. 
@@ -100,12 +99,12 @@ class LiteralBase { // Gets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, + NativeT Get(absl::Span multi_index, const ShapeIndex& shape_index) const; // Overloads of Get for array literals. CHECKs if the literal is not // array-shaped and dense. template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; + NativeT Get(absl::Span multi_index) const; // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. @@ -114,7 +113,7 @@ class LiteralBase { // As Get(), but determines the correct type and converts the value // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, + string GetAsString(absl::Span multi_index, const ShapeIndex& shape_index = {}) const; // As GetSparseElement(), but determines the correct type and converts the // value into text. @@ -122,14 +121,13 @@ class LiteralBase { const ShapeIndex& shape_index = {}) const; // As Get(), but determines the correct type and converts the value into // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; + StatusOr GetIntegralAsS64(absl::Span multi_index) const; // Returns the multi-index of the element in a sparse literal at the given // sparse element number. The sparse element number is the position with in // the sparse array's list of (index, value) pairs, and is checked against the // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( + absl::Span GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; // Returns the value of the element in a sparse literal at the given sparse @@ -150,12 +148,12 @@ class LiteralBase { // // This literal must have a dense layout. 
void EachCellAsString( - const std::function indices, + const std::function indices, const string& value)>& per_cell) const; template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; + void EachCell( + std::function indices, NativeT value)> + per_cell) const; // Returns whether every element in this literal is equal to value. // @@ -195,9 +193,12 @@ class LiteralBase { // Literal consists entirely of the first element of the literal. bool IsAllFirst() const; + // Literal consists entirely of an iota. + bool IsR1Iota() const; + // Returns whether this literal is zero at the specified index. This literal // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; + bool IsZero(absl::Span indices) const; // Returns the count of the elements in the array at the given shape index in // this literal. @@ -270,13 +271,12 @@ class LiteralBase { // implementation currently only supports monotonic dim0-major layouts. // This literal must be an array. StatusOr> Reshape( - tensorflow::gtl::ArraySlice dimensions) const; + absl::Span dimensions) const; // Creates a new literal by broadcasting this literal with `dimensions` to // yield a literal of shape `result_shape`. StatusOr> Broadcast( - const Shape& result_shape, - tensorflow::gtl::ArraySlice dimensions) const; + const Shape& result_shape, absl::Span dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -285,8 +285,7 @@ class LiteralBase { // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. // This literal must be an array. 
- std::unique_ptr Transpose( - tensorflow::gtl::ArraySlice permutation) const; + std::unique_ptr Transpose(absl::Span permutation) const; // Creates a sub-array from this literal by extracting the indices // [start_index, limit_index) of each dimension. The result literal has the @@ -294,9 +293,8 @@ class LiteralBase { // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. // This literal must be an array. - std::unique_ptr Slice( - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) const; + std::unique_ptr Slice(absl::Span start_indices, + absl::Span limit_indices) const; // Creates a literal with a prepended dimension with bound "times"; e.g. a // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this @@ -312,7 +310,7 @@ class LiteralBase { // Note: It's an antipattern to use this method then immediately call // MutableLiteralBase::Populate on the result (since that results in zero // initialization, then reinitialization. Conside if a call to - // MakeUnique(shape), followed by the call to + // absl::make_unique(shape), followed by the call to // MutableLiteralBase::Populate can be used instead. static std::unique_ptr CreateFromShape(const Shape& shape); @@ -325,9 +323,9 @@ class LiteralBase { // Returns the buffer holding the array data for this piece as an array // slice. This piece must be array-shaped. template - tensorflow::gtl::ArraySlice data() const; + absl::Span data() const; template - tensorflow::gtl::MutableArraySlice data(); + absl::Span data(); // Returns the buffer holding the array data for this piece as a void*. This // piece must be array-shaped. @@ -338,9 +336,9 @@ class LiteralBase { // is CHECKed against the dimension sizes of the array. This piece must be // array-shaped. 
template - NativeT Get(tensorflow::gtl::ArraySlice index) const; + NativeT Get(absl::Span index) const; template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); + void Set(absl::Span index, NativeT value); // Gets/sets the buffer holding the array data. char* buffer() const { return buffer_; } @@ -542,8 +540,7 @@ class LiteralBase { private: template std::unique_ptr SliceInternal( - const Shape& result_shape, - tensorflow::gtl::ArraySlice start_indices) const; + const Shape& result_shape, absl::Span start_indices) const; }; // Abstract base class representing a mutable literal in XLA. @@ -551,13 +548,12 @@ class MutableLiteralBase : public LiteralBase { public: virtual ~MutableLiteralBase() = 0; - // Returns a MutableArraySlice view of the array for this literal for the + // Returns a Span view of the array for this literal for the // given NativeT (e.g., float). CHECKs if the subshape of the literal at the // given ShapeIndex is not array. See primitive_util.h for the mapping from // XLA type to native type. template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); + absl::Span data(const ShapeIndex& shape_index = {}); // Unhide const method from parent class. using LiteralBase::data; @@ -584,8 +580,7 @@ class MutableLiteralBase : public LiteralBase { // are populated. template void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); + absl::Span values, bool sort = true); // Copy values from 'src_literal' rooted at 'src_shape_index' into this // literal rooted at 'dest_shape_index'. The subshape of this literal rooted @@ -606,39 +601,38 @@ class MutableLiteralBase : public LiteralBase { // corresponding base indices being 0. // This literal and 'src_literal' must be arrays. 
Status CopySliceFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + absl::Span src_base, + absl::Span dest_base, + absl::Span copy_size); // Copies one element from src_literal[src_index] to (*this)[dest_index]. Status CopyElementFrom(const LiteralSlice& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index); + absl::Span src_index, + absl::Span dest_index); // Sets an element in the literal at the given index. The multi_index is // CHECKed against the dimension sizes. template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); + void Set(absl::Span multi_index, const ShapeIndex& shape_index, + NativeT value); // Overloads of Set for array literals. CHECKs if the literal is not // array-shaped and dense. template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); + void Set(absl::Span multi_index, NativeT value); // Appends the given element to the literal. If the elements are not appended // in sorted order, then SortSparseElements should be called before calling // other methods. This literal must have a sparse layout. template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); + void AppendSparseElement(absl::Span multi_index, NativeT value, + const ShapeIndex& shape_index = {}); // Sorts the elements in a sparse array. void SortSparseElements(const ShapeIndex& shape_index = {}); // As Set(), but truncates `value` to the literal element type before storing. // This literal must be an array. - Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value); + Status SetIntegralAsS64(absl::Span multi_index, int64 value); // Populate this literal with the given values. 
Examples: // @@ -653,7 +647,7 @@ class MutableLiteralBase : public LiteralBase { // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 // array of S32. template - void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(absl::Span values); void PopulateR1(const tensorflow::core::Bitmap& values); template void PopulateR2(std::initializer_list> values); @@ -670,7 +664,7 @@ class MutableLiteralBase : public LiteralBase { // in this literal object. // // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // NativeT(absl::Span indexes) or compatible. // // This literal must have a dense layout. template @@ -690,8 +684,7 @@ class MutableLiteralBase : public LiteralBase { // moved into the tuple elements of a new tuple-shaped Literal which is // returned. Upon return, each of the Literals in 'elements' is set to a nil // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); + static Literal MoveIntoTuple(absl::Span elements); // Serialize from a proto. static StatusOr> CreateFromProto( @@ -709,20 +702,20 @@ class MutableLiteralBase : public LiteralBase { // arguments one by one. template Status CopySliceFromInternal(const LiteralBase& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + absl::Span src_base, + absl::Span dest_base, + absl::Span copy_size); // Utility structure which is used to create the optimal configuration for // a ShapeUtil::ForEachIndex() scan across two literals. struct StrideConfig { StrideConfig(const Shape& source_shape, const Shape& dest_shape, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // The dimensions of the stride operation. Essentially every dimension // will be iterated from base[i] to base[i]+dimensions[i], in step[i] // steps. 
- tensorflow::gtl::ArraySlice dimensions; + absl::Span dimensions; DimensionVector base; DimensionVector step; int64 minor_dimension = 0; @@ -851,7 +844,7 @@ class BorrowingLiteral : public LiteralBase { // This constructor is only used for array shapes. BorrowingLiteral(const char* src_buf_ptr, const Shape& shape); // Similar as above, except to be used for constructing non-nested tuples. - BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs, + BorrowingLiteral(absl::Span src_buf_ptrs, const Shape& shape); // TODO(b/79707221): adding constructors for nested tuples as well. @@ -871,7 +864,7 @@ class BorrowingLiteral : public LiteralBase { }; template -tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { +absl::Span LiteralBase::Piece::data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -879,12 +872,12 @@ tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const { << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(buffer()), element_count()); + return absl::Span(reinterpret_cast(buffer()), + element_count()); } template -tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { +absl::Span LiteralBase::Piece::data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType()) @@ -892,20 +885,19 @@ tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() { << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) << " type, but literal element type is " << PrimitiveType_Name(subshape().element_type()); - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(buffer()), element_count()); + return absl::Span(reinterpret_cast(buffer()), + element_count()); } 
template -NativeT LiteralBase::Piece::Get( - tensorflow::gtl::ArraySlice multi_index) const { +NativeT LiteralBase::Piece::Get(absl::Span multi_index) const { CHECK(LayoutUtil::IsDenseArray(subshape())); return data()[IndexUtil::MultidimensionalIndexToLinearIndex( subshape(), multi_index)]; } template -void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, +void LiteralBase::Piece::Set(absl::Span multi_index, NativeT value) { CHECK(LayoutUtil::IsDenseArray(subshape())); data()[IndexUtil::MultidimensionalIndexToLinearIndex( @@ -913,39 +905,37 @@ void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index, } template -tensorflow::gtl::ArraySlice LiteralBase::data( +absl::Span LiteralBase::data( const ShapeIndex& shape_index) const { return piece(shape_index).data(); } template -tensorflow::gtl::MutableArraySlice MutableLiteralBase::data( - const ShapeIndex& shape_index) { +absl::Span MutableLiteralBase::data(const ShapeIndex& shape_index) { return piece(shape_index).data(); } template -inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index, +inline NativeT LiteralBase::Get(absl::Span multi_index, const ShapeIndex& shape_index) const { return piece(shape_index).Get(multi_index); } template -inline NativeT LiteralBase::Get( - tensorflow::gtl::ArraySlice multi_index) const { +inline NativeT LiteralBase::Get(absl::Span multi_index) const { return root_piece().Get(multi_index); } template -inline void MutableLiteralBase::Set( - tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value) { +inline void MutableLiteralBase::Set(absl::Span multi_index, + const ShapeIndex& shape_index, + NativeT value) { return piece(shape_index).Set(multi_index, value); } template -inline void MutableLiteralBase::Set( - tensorflow::gtl::ArraySlice multi_index, NativeT value) { +inline void MutableLiteralBase::Set(absl::Span multi_index, + NativeT value) { return root_piece().Set(multi_index, value); } @@ -964,7 +954,7 @@ 
NativeT LiteralBase::GetSparseElement(int64 sparse_element_number, template void MutableLiteralBase::AppendSparseElement( - tensorflow::gtl::ArraySlice multi_index, NativeT value, + absl::Span multi_index, NativeT value, const ShapeIndex& shape_index) { Piece& p = piece(shape_index); const Shape& subshape = p.subshape(); @@ -980,8 +970,7 @@ void MutableLiteralBase::AppendSparseElement( template void LiteralBase::EachCell( - std::function indices, - NativeT value)> + std::function indices, NativeT value)> per_cell) const { if (ShapeUtil::IsZeroElementArray(shape())) { return; @@ -989,12 +978,11 @@ void LiteralBase::EachCell( std::vector indices(ShapeUtil::Rank(shape()), 0); do { per_cell(indices, Get(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); + } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices))); } template -inline void MutableLiteralBase::PopulateR1( - tensorflow::gtl::ArraySlice values) { +inline void MutableLiteralBase::PopulateR1(absl::Span values) { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); @@ -1039,8 +1027,9 @@ void MutableLiteralBase::PopulateFromArray(const Array& values) { for (int dim = 0; dim < values.num_dimensions(); ++dim) { CHECK_EQ(values.dim(dim), shape().dimensions(dim)); } - values.Each([this](tensorflow::gtl::ArraySlice indices, - NativeT value) { this->Set(indices, value); }); + values.Each([this](absl::Span indices, NativeT value) { + this->Set(indices, value); + }); } template @@ -1059,9 +1048,9 @@ void MutableLiteralBase::PopulateR4FromArray4D(const Array4D& values) { } template -void MutableLiteralBase::PopulateSparse( - SparseIndexArray indices, tensorflow::gtl::ArraySlice values, - bool sort) { +void MutableLiteralBase::PopulateSparse(SparseIndexArray indices, + absl::Span values, + bool sort) { CHECK(LayoutUtil::IsSparseArray(shape())); int rank = ShapeUtil::Rank(shape()); CHECK_EQ(indices.rank(), rank); @@ 
-1071,7 +1060,7 @@ void MutableLiteralBase::PopulateSparse( CHECK_LE(num_elements, max_elements); CHECK_EQ(num_elements, indices.index_count()); auto root_data = root_piece().data(); - // Piece::data() returns an ArraySlice of size equal to the number of indices + // Piece::data() returns a Span of size equal to the number of indices // in the SparseIndexArray. So there is no need to adjust the size of the data // here. It is enough to just copy the incoming values into the data buffer. std::copy(values.begin(), values.end(), root_data.begin()); @@ -1091,14 +1080,14 @@ Status MutableLiteralBase::PopulateInternal(const FnType& generator, TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); - tensorflow::gtl::MutableArraySlice literal_data = data(); + absl::Span literal_data = data(); if (rank > 0) { StrideConfig stride_config(this_shape, this_shape, AsInt64Slice(this_shape.dimensions())); int64 minor_dimension_size = ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); - auto init_function = [&](tensorflow::gtl::ArraySlice indexes) { + auto init_function = [&](absl::Span indexes) { DimensionVector minor_scan_indexes(rank, 0); const int64 index = IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); @@ -1116,7 +1105,7 @@ Status MutableLiteralBase::PopulateInternal(const FnType& generator, ShapeUtil::ForEachIndex( this_shape, stride_config.base, stride_config.dimensions, stride_config.step, - [&init_function](tensorflow::gtl::ArraySlice indexes) { + [&init_function](absl::Span indexes) { init_function(indexes); return true; }); @@ -1154,15 +1143,15 @@ std::unique_ptr LiteralBase::Replicate(int64 times) const { for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } - auto literal = - MakeUnique(ShapeUtil::MakeShape(shape().element_type(), bounds)); + auto literal = absl::make_unique( + ShapeUtil::MakeShape(shape().element_type(), bounds)); int64 
elements = ShapeUtil::ElementsIn(literal->shape()); if (elements == 0) { return literal; } DimensionVector output_indices(bounds.size(), 0); - tensorflow::gtl::ArraySlice input_indices = output_indices; + absl::Span input_indices = output_indices; input_indices.remove_prefix(1); bool done = false; diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 94993cc87443ba8c22fd7c2eacfc8756d3f48edc..3d8725ed7051cafc97987f25a96004fa876dfdd3 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -19,16 +19,16 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" -using tensorflow::strings::Appendf; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrAppendFormat; +using absl::StrCat; namespace xla { namespace literal_comparison { @@ -38,7 +38,8 @@ namespace { // between the left-hand-side and right-hand-side, by bit-casting to UnsignedT // -- on miscompare, a nice error message is given in the AssertionFailure. 
template -Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { +Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs, + absl::Span multi_index) { auto ulhs = tensorflow::bit_cast(lhs); auto urhs = tensorflow::bit_cast(rhs); auto lhs_double = static_cast(lhs); @@ -46,9 +47,10 @@ Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { if (ulhs != urhs) { return InvalidArgument( "floating values are not bitwise-equal; and equality testing " - "was requested: %s=%g=%a vs %s=%g=%a", - StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double, - StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double); + "was requested: %s=%g=%a vs %s=%g=%a at array index %s", + StrCat(absl::Hex(ulhs)), lhs_double, lhs_double, + StrCat(absl::Hex(urhs)), rhs_double, rhs_double, + LiteralUtil::MultiIndexAsString(multi_index)); } return Status::OK(); } @@ -57,39 +59,47 @@ Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { // bitwise helper above (this is the un-specialized fallback, to just use the // default gunit implementation). template -Status CompareEqual(NativeT lhs, NativeT rhs) { +Status CompareEqual(NativeT lhs, NativeT rhs, + absl::Span multi_index) { if (lhs == rhs) { return Status::OK(); } - return InvalidArgument("Expected equality of these values:\n %s\n %s", - StrCat(lhs).c_str(), StrCat(rhs).c_str()); + return InvalidArgument( + "first mismatch at array index %s:\n expected value: %s\n actual " + "value: %s", + LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs)); } // Specializations for floating types that do bitwise comparisons when equality // comparison is requested. 
template <> -Status CompareEqual(bfloat16 lhs, bfloat16 rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); +Status CompareEqual(bfloat16 lhs, bfloat16 rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual(Eigen::half lhs, Eigen::half rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); +Status CompareEqual(Eigen::half lhs, Eigen::half rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual(float lhs, float rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); +Status CompareEqual(float lhs, float rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual(double lhs, double rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); +Status CompareEqual(double lhs, double rhs, + absl::Span multi_index) { + return CompareFloatsBitwiseEqual(lhs, rhs, multi_index); } template <> -Status CompareEqual(complex64 lhs, complex64 rhs) { - auto res = CompareEqual(lhs.real(), rhs.real()); +Status CompareEqual(complex64 lhs, complex64 rhs, + absl::Span multi_index) { + auto res = CompareEqual(lhs.real(), rhs.real(), multi_index); if (!res.ok()) { return res; } - return CompareEqual(lhs.imag(), rhs.imag()); + return CompareEqual(lhs.imag(), rhs.imag(), multi_index); } // A recursive function which iterates through every index of expected and @@ -97,18 +107,18 @@ Status CompareEqual(complex64 lhs, complex64 rhs) { // elements are equal. 
template Status Equal(LiteralSlice expected, LiteralSlice actual, - tensorflow::gtl::MutableArraySlice multi_index, - int64 dimension) { + absl::Span multi_index, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { NativeT expected_value = expected.Get(multi_index); NativeT actual_value = actual.Get(multi_index); - return CompareEqual(expected_value, actual_value); + return CompareEqual(expected_value, actual_value, multi_index); } Status result; for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { multi_index[dimension] = i; - result.Update(Equal(expected, actual, multi_index, dimension + 1)); + TF_RETURN_IF_ERROR( + Equal(expected, actual, multi_index, dimension + 1)); } return result; } @@ -152,15 +162,26 @@ bool NanMismatch(half expected, half actual, bool relaxed_nans) { static_cast(actual), relaxed_nans); } +// Returns whether the given value is infinity. +template +bool IsInf(NativeT val) { + return std::isinf(val); +} + +template <> +bool IsInf(half val) { + return std::isinf(static_cast(val)); +} + // Converts the given floating-point value to a string. template string FpValueToString(NativeT value) { - return Printf("%8.4g", static_cast(value)); + return absl::StrFormat("%8.4g", static_cast(value)); } template <> string FpValueToString(complex64 value) { - return Printf("%8.4g + %8.4fi", value.real(), value.imag()); + return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); } // Returns the absolute value of the given floating point value. 
This function @@ -215,13 +236,12 @@ class NearComparator { } string ToString(const Shape& shape) const { - return Printf( + return absl::StrFormat( "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", - FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), + FpValueToString(actual), FpValueToString(expected), LiteralUtil::MultiIndexAsString( IndexUtil::LinearIndexToMultidimensionalIndex(shape, - linear_index)) - .c_str(), + linear_index)), rel_error, abs_error); } }; @@ -240,17 +260,12 @@ class NearComparator { // Runs the comparison between expected and actual literals. Status Run() { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, ToStringTruncated(expected_)); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, ToStringTruncated(actual_)); - // If the shapes mismatch, we simply fail the expectation instead of // printing out data, as it's a type error rather than a value error. TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); if (!ShapeUtil::IsArray(expected_.shape())) { return InvalidArgument("Expected array shape; got %s.", - ShapeUtil::HumanString(expected_.shape()).c_str()); + ShapeUtil::HumanString(expected_.shape())); } mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); @@ -263,7 +278,7 @@ class NearComparator { } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) { miscompare_callback_(expected_, actual_, mismatches_); } - return InvalidArgument("%s", ErrorMessage().c_str()); + return InvalidArgument("%s", ErrorMessage()); } // Insert the given absolute value into the absolute value bucket vector. The @@ -288,8 +303,7 @@ class NearComparator { } // Insert the given error into the given error bucket vector. 
- void UpdateErrorBucket( - float error, tensorflow::gtl::MutableArraySlice error_buckets) { + void UpdateErrorBucket(float error, absl::Span error_buckets) { CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); for (int i = 0; i < error_buckets.size(); ++i) { if (error >= kErrorBucketBounds[i]) { @@ -300,12 +314,13 @@ class NearComparator { // Compares the two given elements from the expected and actual literals at // the given literal_index and keeps track of various mismatch statistics. - void CompareValues(NativeT expected, NativeT actual, int64 linear_index) { + template + void CompareValues(T expected, T actual, int64 linear_index) { const bool is_nan_mismatch = NanMismatch(expected, actual, error_.relaxed_nans); float abs_error; float rel_error; - if (actual == expected) { + if (CompareEqual(expected, actual, {linear_index}).ok()) { abs_error = 0; rel_error = 0; } else if (is_nan_mismatch) { @@ -316,6 +331,12 @@ class NearComparator { // weak ordering requirement of std containers. abs_error = std::numeric_limits::infinity(); rel_error = std::numeric_limits::infinity(); + } else if (IsInf(expected) || IsInf(actual)) { + // If either the expected or actual value is infinity but not both, + // then both absolute and relative error are regarded as inifity. + CHECK(!CompareEqual(expected, actual, {linear_index}).ok()); + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); } else { abs_error = FpAbsoluteValue(actual - expected); rel_error = abs_error / FpAbsoluteValue(expected); @@ -329,11 +350,11 @@ class NearComparator { // bound is exceeded and vice versa. 
if (is_abs_mismatch) { num_abs_mismatches_++; - UpdateErrorBucket(rel_error, &rel_error_buckets_); + UpdateErrorBucket(rel_error, absl::MakeSpan(rel_error_buckets_)); } if (is_rel_mismatch) { num_rel_mismatches_++; - UpdateErrorBucket(abs_error, &abs_error_buckets_); + UpdateErrorBucket(abs_error, absl::MakeSpan(abs_error_buckets_)); } UpdateAbsValueBucket(actual, is_mismatch); @@ -358,15 +379,36 @@ class NearComparator { mismatches_.data()[linear_index] = true; } + // For complex64 types, we compare real and imaginary parts individually. + void CompareValues(complex64 expected, complex64 actual, int64 linear_index) { + bool mismatch = false; + CompareValues(expected.real(), actual.real(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for real part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + CompareValues(expected.imag(), actual.imag(), linear_index); + if (mismatches_.data()[linear_index] == true) { + mismatch = true; + // Delay the mismatch count increase for imag part, instead increase + // mismatch by 1 for the entire complex number. + num_mismatches_--; + } + if (mismatch == true) { + num_mismatches_++; + } + mismatches_.data()[linear_index] = mismatch; + } + // Compares the two literals elementwise. void CompareLiterals() { // Fast path optimization for the case were layouts match. if (LayoutUtil::Equal(actual_.shape().layout(), expected_.shape().layout())) { - tensorflow::gtl::ArraySlice expected_data = - expected_.data(); - tensorflow::gtl::ArraySlice actual_data = - actual_.data(); + absl::Span expected_data = expected_.data(); + absl::Span actual_data = actual_.data(); const int64 len = expected_data.size(); for (int64 i = 0; i < len; ++i) { CompareValues(expected_data[i], actual_data[i], i); @@ -402,23 +444,23 @@ class NearComparator { auto percent_string = [](float a, float b) { float pct = b == 0.0 ? 
0.0 : 100.0 * a / b; - return Printf("%0.4f%%", pct); + return absl::StrFormat("%0.4f%%", pct); }; - Appendf(&out, - "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound " - "%g, rel bound %g\n", - num_mismatches_, - percent_string(num_mismatches_, element_count).c_str(), - ShapeUtil::HumanString(actual_.shape()).c_str(), - ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); + StrAppendFormat( + &out, + "\nMismatch count %d (%s) in shape %s (%d elements), abs bound " + "%g, rel bound %g\n", + num_mismatches_, percent_string(num_mismatches_, element_count), + ShapeUtil::HumanString(actual_.shape()), + ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); if (num_nan_mismatches_ > 0) { StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); } - Appendf(&out, "Top relative error mismatches:\n"); + StrAppendFormat(&out, "Top relative error mismatches:\n"); for (auto it = top_rel_mismatches_.rbegin(); it != top_rel_mismatches_.rend(); ++it) { - StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n"); + StrAppend(&out, " ", it->ToString(actual_.shape()), "\n"); } if (!detailed_message_) { @@ -430,36 +472,37 @@ class NearComparator { for (int i = 0; i < abs_value_buckets_.size(); ++i) { const int64 bucket_size = abs_value_buckets_[i].first; const int64 bucket_mismatches = abs_value_buckets_[i].second; - string mismatch_str = bucket_mismatches > 0 - ? Printf(", mismatches %lld", bucket_mismatches) - : ""; - Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n", - kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], - bucket_size, percent_string(bucket_size, element_count).c_str(), - mismatch_str.c_str()); + string mismatch_str = + bucket_mismatches > 0 + ? 
absl::StrFormat(", mismatches %d", bucket_mismatches) + : ""; + StrAppendFormat(&out, " %-6g <= x < %-6g : %7d (%9s)%s\n", + kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], + bucket_size, percent_string(bucket_size, element_count), + mismatch_str); } auto print_accum_buckets = [&](const string& header, int64 total, - tensorflow::gtl::ArraySlice buckets) { + absl::Span buckets) { StrAppend(&out, header, ":\n"); - Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0], - total - buckets[0], - percent_string(total - buckets[0], total).c_str()); + StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0], + total - buckets[0], + percent_string(total - buckets[0], total)); CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); for (int i = 0; i < kErrorBucketBounds.size(); ++i) { - Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i], - buckets[i], percent_string(buckets[i], total).c_str()); + StrAppendFormat(&out, " >= %-6g : %7d (%s)\n", kErrorBucketBounds[i], + buckets[i], percent_string(buckets[i], total)); } }; - Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n", - error_.abs, num_abs_mismatches_, - percent_string(num_abs_mismatches_, element_count).c_str()); + StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n", + error_.abs, num_abs_mismatches_, + percent_string(num_abs_mismatches_, element_count)); print_accum_buckets( "Relative error breakdown of elements exceeding abs error bound", num_abs_mismatches_, rel_error_buckets_); - Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n", - error_.rel, num_rel_mismatches_, - percent_string(num_rel_mismatches_, element_count).c_str()); + StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n", + error_.rel, num_rel_mismatches_, + percent_string(num_rel_mismatches_, element_count)); print_accum_buckets( "Absolute error breakdown of elements exceeding rel error bound", num_rel_mismatches_, abs_error_buckets_); @@ -528,6 
+571,63 @@ constexpr std::array NearComparator::kAbsValueBucketBounds; template constexpr std::array NearComparator::kErrorBucketBounds; +Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + std::vector multi_index(expected.shape().dimensions_size(), 0); + auto index = absl::MakeSpan(multi_index); + Status result; + switch (expected.shape().element_type()) { + case PRED: + result = Equal(expected, actual, index, 0); + break; + case U8: + result = Equal(expected, actual, index, 0); + break; + case S32: + result = Equal(expected, actual, index, 0); + break; + case S64: + result = Equal(expected, actual, index, 0); + break; + case U32: + result = Equal(expected, actual, index, 0); + break; + case U64: + result = Equal(expected, actual, index, 0); + break; + case BF16: + result = Equal(expected, actual, index, 0); + break; + case F16: + result = Equal(expected, actual, index, 0); + break; + case F32: + result = Equal(expected, actual, index, 0); + break; + case F64: + result = Equal(expected, actual, index, 0); + break; + case C64: + result = Equal(expected, actual, index, 0); + break; + case TUPLE: { + for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + result.Update(EqualHelper(LiteralSlice(expected, {i}), + LiteralSlice(actual, {i}))); + } + break; + } + case TOKEN: + // Tokens have no on-device representation and are trivially equal. + return Status::OK(); + default: + LOG(FATAL) << "Unsupported primitive type: " + << PrimitiveType_Name(expected.shape().element_type()); + } + + return result; +} + // Helper function for comparing two literals for nearness. Handles tuple-shapes // via recursion. shape_index is the ShapeIndex of expected (or actual) // currently being compared. 
@@ -544,17 +644,18 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, const auto actual_element = LiteralSlice(actual, {i}); ShapeIndex element_index = shape_index; element_index.push_back(i); - Status res = + Status element_result = NearHelper(expected_element, actual_element, error, detailed_message, miscompare_callback, element_index); - if (!res.ok()) { - string err_message = Printf("\nArray at shape index %s%s", - element_index.ToString().c_str(), - res.error_message().c_str()); + if (!element_result.ok()) { + element_result = InvalidArgument("Array at shape index %s, %s", + element_index.ToString(), + element_result.error_message()); if (return_status.ok()) { - return_status = res; + return_status = element_result; } else { - return_status = AppendStatus(return_status, res.error_message()); + return_status = + AppendStatus(return_status, element_result.error_message()); } } } @@ -562,10 +663,10 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, // Emit a top-level error message containing the top-level shape in case // of mismatch. int64 total_elements = RecursiveElementCount(actual.shape()); - return_status = InvalidArgument( - "\nMismatches in shape %s (%lld elements):\n%s", - ShapeUtil::HumanString(actual.shape()).c_str(), total_elements, - return_status.error_message().c_str()); + return_status = + InvalidArgument("\nMismatches in shape %s (%d elements):\n%s", + ShapeUtil::HumanString(actual.shape()), + total_elements, return_status.error_message()); } return return_status; } @@ -600,8 +701,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, } } - // Non-floating point literal. - return literal_comparison::Equal(expected, actual); + // Non-floating point, non-tuple literal. 
+ return EqualHelper(expected, actual); } } // namespace @@ -609,14 +710,14 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, Status EqualShapes(const Shape& expected, const Shape& actual) { if (expected.element_type() != actual.element_type()) { return InvalidArgument("element type mismatch, want: %s got %s", - ShapeUtil::HumanString(expected).c_str(), - ShapeUtil::HumanString(actual).c_str()); + ShapeUtil::HumanString(expected), + ShapeUtil::HumanString(actual)); } if (ShapeUtil::IsTuple(expected)) { if (ShapeUtil::TupleElementCount(expected) != ShapeUtil::TupleElementCount(actual)) { return InvalidArgument( - "want tuple element count: %lld got tuple element count: %lld", + "want tuple element count: %d got tuple element count: %d", ShapeUtil::TupleElementCount(expected), ShapeUtil::TupleElementCount(actual)); } @@ -630,14 +731,13 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { } else if (ShapeUtil::IsArray(expected)) { if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { return InvalidArgument("want rank of %s got rank of %s", - ShapeUtil::HumanString(expected).c_str(), - ShapeUtil::HumanString(actual).c_str()); + ShapeUtil::HumanString(expected), + ShapeUtil::HumanString(actual)); } if (expected.element_type() != actual.element_type()) { - return InvalidArgument( - "mismatch in primitive type %s vs %s", - PrimitiveType_Name(expected.element_type()).c_str(), - PrimitiveType_Name(actual.element_type()).c_str()); + return InvalidArgument("mismatch in primitive type %s vs %s", + PrimitiveType_Name(expected.element_type()), + PrimitiveType_Name(actual.element_type())); } if (expected.dimensions_size() != actual.dimensions_size()) { return InvalidArgument("want dimensions_size %d got dimensions_size %d", @@ -648,8 +748,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { if (expected.dimensions(i) != actual.dimensions(i)) { return InvalidArgument( "mismatch in dimension #%d expected: %s actual: 
%s", i, - ShapeUtil::HumanString(expected).c_str(), - ShapeUtil::HumanString(actual).c_str()); + ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual)); } } } @@ -657,83 +756,43 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { return Status::OK(); } +namespace { + +// If result is an error, extend the error message with the expected and actual +// literals. +Status EmitLiteralsInErrorMessage(const Status& result, + const LiteralSlice& expected, + const LiteralSlice& actual) { + if (result.ok()) { + return result; + } + return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s", + result.error_message(), ToStringTruncated(expected), + ToStringTruncated(actual)); +} + +} // namespace + Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { VLOG(1) << "expected:"; XLA_VLOG_LINES(1, expected.ToString()); VLOG(1) << "actual:"; XLA_VLOG_LINES(1, actual.ToString()); - - TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); - std::vector multi_index(expected.shape().dimensions_size(), 0); - Status result; - switch (expected.shape().element_type()) { - case PRED: - result = Equal(expected, actual, &multi_index, 0); - break; - case U8: - result = Equal(expected, actual, &multi_index, 0); - break; - case S32: - result = Equal(expected, actual, &multi_index, 0); - break; - case S64: - result = Equal(expected, actual, &multi_index, 0); - break; - case U32: - result = Equal(expected, actual, &multi_index, 0); - break; - case U64: - result = Equal(expected, actual, &multi_index, 0); - break; - case BF16: - result = Equal(expected, actual, &multi_index, 0); - break; - case F16: - result = Equal(expected, actual, &multi_index, 0); - break; - case F32: - result = Equal(expected, actual, &multi_index, 0); - break; - case F64: - result = Equal(expected, actual, &multi_index, 0); - break; - case C64: - result = Equal(expected, actual, &multi_index, 0); - break; - case TUPLE: { - for (int i = 0; i < 
ShapeUtil::TupleElementCount(expected.shape()); ++i) { - result.Update( - Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}))); - } - break; - } - case TOKEN: - // Tokens have no on-device representation and are trivially equal. - return Status::OK(); - default: - LOG(FATAL) - << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " - << PrimitiveType_Name(expected.shape().element_type()); - } - - if (result.ok()) { - return Status::OK(); - } - - return AppendStatus(result, - tensorflow::strings::Printf( - "\nat index: %s\nexpected: %s\nactual: %s", - LiteralUtil::MultiIndexAsString(multi_index).c_str(), - ToStringTruncated(expected).c_str(), - ToStringTruncated(actual).c_str())); + Status result = EqualHelper(expected, actual); + return EmitLiteralsInErrorMessage(result, expected, actual); } Status Near(const LiteralSlice& expected, const LiteralSlice& actual, const ErrorSpec& error, bool detailed_message, const MiscompareCallback& miscompare_callback) { - return NearHelper(expected, actual, error, detailed_message, - miscompare_callback, - /*shape_index=*/{}); + VLOG(1) << "Expected literal:"; + XLA_VLOG_LINES(1, expected.ToString()); + VLOG(1) << "Actual literal:"; + XLA_VLOG_LINES(1, actual.ToString()); + Status result = + NearHelper(expected, actual, error, detailed_message, miscompare_callback, + /*shape_index=*/{}); + return EmitLiteralsInErrorMessage(result, expected, actual); } string ToStringTruncated(const LiteralSlice& literal) { diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index e8f919950f0efc8b508f7ad4aee5233176bc0abd..1a64594db86af31dcc196725d4b4f2a3ad9e4746 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -17,6 +17,9 @@ limitations under the License. 
#include +#include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -33,7 +36,6 @@ limitations under the License. namespace xla { namespace { -using tensorflow::gtl::ArraySlice; using ::testing::ElementsAre; using ::testing::HasSubstr; @@ -96,42 +98,42 @@ class LiteralUtilTest : public ::testing::Test { TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); - ASSERT_EQ("true", true_lit->ToString()); + EXPECT_EQ("true", true_lit->ToString()); auto false_lit = LiteralUtil::CreateR0(false); - ASSERT_EQ("false", false_lit->ToString()); + EXPECT_EQ("false", false_lit->ToString()); auto u32_lit = LiteralUtil::CreateR0(42); - ASSERT_EQ("42", u32_lit->ToString()); + EXPECT_EQ("42", u32_lit->ToString()); auto s32_lit = LiteralUtil::CreateR0(-999); - ASSERT_EQ("-999", s32_lit->ToString()); + EXPECT_EQ("-999", s32_lit->ToString()); auto f32_lit = LiteralUtil::CreateR0(3.14f); - ASSERT_EQ("3.14", f32_lit->ToString()); + EXPECT_EQ("3.14", f32_lit->ToString()); auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - ASSERT_EQ("0.5", f16_lit->ToString()); + EXPECT_EQ("0.5", f16_lit->ToString()); auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); - ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); + EXPECT_EQ("(3.14, 2.78)", c64_lit->ToString()); auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - ASSERT_EQ("0.5", bf16_lit->ToString()); + EXPECT_EQ("0.5", bf16_lit->ToString()); - // 3.14 will be truncated to 3.125 in bfloat16 format. + // 3.14 will be rounded to 3.14062 in bfloat16 format. 
auto bf16_lit_truncated = LiteralUtil::CreateR0(static_cast(3.14f)); - ASSERT_EQ("3.125", bf16_lit_truncated->ToString()); + ASSERT_EQ("3.14062", bf16_lit_truncated->ToString()); auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); - ASSERT_EQ("9", bf16_lit_truncated2->ToString()); + EXPECT_EQ("9", bf16_lit_truncated2->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - ASSERT_EQ("{101}", pred_vec->ToString()); + EXPECT_EQ("{101}", pred_vec->ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -141,7 +143,7 @@ TEST_F(LiteralUtilTest, R2ToString) { { 3, 4 }, { 5, 6 } })"; - ASSERT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal->ToString()); } TEST_F(LiteralUtilTest, R3ToString) { @@ -155,7 +157,7 @@ TEST_F(LiteralUtilTest, R3ToString) { { { 5 }, { 6 } } })"; - ASSERT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal->ToString()); } TEST_F(LiteralUtilTest, TupleToString) { @@ -169,7 +171,7 @@ f32[2,2] { { 3, 4 } } ))"; - ASSERT_EQ(expected, tuple->ToString()); + EXPECT_EQ(expected, tuple->ToString()); } TEST_F(LiteralUtilTest, CreateR3FromArray3d) { @@ -195,7 +197,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { { 9, 10 }, { 11, 12 } } })"; - ASSERT_EQ(expected, result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, CreateSparse) { @@ -219,9 +221,9 @@ TEST_F(LiteralUtilTest, CreateSparse) { std::vector expected_values = {8, 9, 7, 10}; EXPECT_EQ(literal->sparse_indices()->data(), - ArraySlice(expected_indices.data(), - expected_indices.num_elements())); - EXPECT_EQ(literal->data(), ArraySlice(expected_values)); + absl::Span(expected_indices.data(), + expected_indices.num_elements())); + EXPECT_EQ(literal->data(), absl::Span(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { @@ -248,7 +250,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { } } })"; - ASSERT_EQ(expected, result); + 
EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { @@ -281,7 +283,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { } } })"; - ASSERT_EQ(expected, result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, EachCellR2F32) { @@ -293,7 +295,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { // clang-format on std::vector> seen; literal->EachCellAsString( - [&seen](ArraySlice indices, const string& value) { + [&seen](absl::Span indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -355,15 +357,15 @@ TEST_F(LiteralUtilTest, TokenEquality) { TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. - auto colmajor = - MakeUnique(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); + auto colmajor = absl::make_unique( + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); colmajor->Set({0, 0}, 1.0); colmajor->Set({0, 1}, 2.0); colmajor->Set({1, 0}, 3.0); colmajor->Set({1, 1}, 4.0); - auto rowmajor = - MakeUnique(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); + auto rowmajor = absl::make_unique( + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); rowmajor->Set({0, 0}, 1.0); rowmajor->Set({0, 1}, 2.0); rowmajor->Set({1, 0}, 3.0); @@ -646,7 +648,7 @@ TEST_F(LiteralUtilTest, TransposeR4) { // clang-format on auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); - reshape->EachCell([&](ArraySlice indices, float value) { + reshape->EachCell([&](absl::Span indices, float value) { EXPECT_EQ(value, original->Get( {indices[2], indices[3], indices[0], indices[1]})); }); @@ -886,7 +888,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 zero_base[] = {0, 0, 0, 0}; const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; - auto init_proc = [&](ArraySlice indexes) { + auto init_proc = [&](absl::Span indexes) { source->Set(indexes, ++seqnr); return true; }; @@ -902,7 +904,7 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { std::vector 
source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); bool matched = true; - auto check_proc = [&](ArraySlice indexes) { + auto check_proc = [&](absl::Span indexes) { std::copy(indexes.begin(), indexes.end(), source_indexes.begin()); std::transform(source_indexes.begin(), source_indexes.end(), src_base, source_indexes.begin(), std::plus()); @@ -1036,7 +1038,7 @@ TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { auto vector = LiteralUtil::CreateR1({5.0, 7.0}); Status status = matrix->CopyFrom(*vector); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Destination subshape incompatible")); } @@ -1089,8 +1091,8 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = MakeUnique(shape); - auto generator = [&](ArraySlice indexes) -> uint32 { + auto literal = absl::make_unique(shape); + auto generator = [&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. 
return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), @@ -1102,7 +1104,7 @@ TEST_F(LiteralUtilTest, Populate) { std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; - auto check_function = [&](ArraySlice indexes) { + auto check_function = [&](absl::Span indexes) { auto value = literal->Get(indexes); matched = matched && (value == generator(indexes)); return matched; @@ -1131,8 +1133,8 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = MakeUnique(shape); - auto generator = [&](ArraySlice indexes) -> uint32 { + auto literal = absl::make_unique(shape); + auto generator = [&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), @@ -1144,7 +1146,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) { std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; - auto check_function = [&](ArraySlice indexes) { + auto check_function = [&](absl::Span indexes) { auto value = literal->Get(indexes); matched = matched && (value == generator(indexes)); return matched; @@ -1323,8 +1325,8 @@ TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { auto literal = LiteralUtil::CreateR0(1234); Status status = literal->BitcastConvert(F64).status(); EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(), - "bit widths are different")); + EXPECT_TRUE( + absl::StrContains(status.error_message(), "bit widths are different")); } TEST_F(LiteralUtilTest, CopyFromProto_Bool) { @@ -1391,10 +1393,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { Literal::CreateFromProto(p)); auto r = literal->data(); ASSERT_EQ(4, r.size()); - ASSERT_EQ(h1, r[0]); - 
ASSERT_EQ(h2, r[1]); - ASSERT_EQ(h2, r[2]); - ASSERT_EQ(h1, r[3]); + EXPECT_EQ(h1, r[0]); + EXPECT_EQ(h2, r[1]); + EXPECT_EQ(h2, r[2]); + EXPECT_EQ(h1, r[3]); } TEST_F(LiteralUtilTest, LiteralSliceTest) { @@ -1558,7 +1560,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { )); - Literal literal = Literal::MoveIntoTuple(&elements); + Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements)); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3); @@ -1577,7 +1579,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) { Literal literal = Literal::MoveIntoTuple({}); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); - ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); + EXPECT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); } TEST_F(LiteralUtilTest, LiteralMoveAssignment) { @@ -1690,7 +1692,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoValues) { *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 3 elements in LiteralProto")); } @@ -1702,7 +1704,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoShape) { proto.add_preds(false); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape")); + EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape")); } TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { @@ -1714,7 +1716,7 @@ TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { proto.add_preds(false); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 3 elements in LiteralProto")); } @@ -1727,7 +1729,7 @@ TEST_F(LiteralUtilTest, 
InvalidProtoTooFewValues) { proto.add_f32s(3.0); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 84 elements in LiteralProto")); } @@ -1740,7 +1742,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { proto.add_s32s(100); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 elements in LiteralProto")); } @@ -1755,7 +1757,7 @@ TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { proto.add_preds(false); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout")); + EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout")); } TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { @@ -1771,7 +1773,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); + EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); } TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { @@ -1794,7 +1796,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); + EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); } TEST_F(LiteralUtilTest, SortSparseElements) { @@ -1804,7 +1806,7 @@ TEST_F(LiteralUtilTest, SortSparseElements) { literal->AppendSparseElement({3, 4, 5}, 3.0); literal->AppendSparseElement({1, 2, 3}, 1.0); literal->SortSparseElements(); - ASSERT_EQ(literal->ToString(false), + 
EXPECT_EQ(literal->ToString(false), "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); } @@ -1812,27 +1814,26 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { std::vector dimensions = {10, 10, 10}; SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}); - ASSERT_EQ( + EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {true, false, true}) ->GetSparseElementAsString(1), "false"); - ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) + EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(int64{2})); - ASSERT_EQ( + absl::StrCat(int64{2})); + EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(double{2.0})); - ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, + absl::StrCat(double{2.0})); + EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {half{1.0}, half{2.0}, half{3.0}}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(static_cast(half{2.0}))); - ASSERT_EQ( - LiteralUtil::CreateSparse( - dimensions, indices, - std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) - ->GetSparseElementAsString(1), - tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); + absl::StrCat(static_cast(half{2.0}))); + EXPECT_EQ(LiteralUtil::CreateSparse( + dimensions, indices, + std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) + ->GetSparseElementAsString(1), + absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 5d33df7d40bf3bfcc8012ce1129d532b34555344..613449cf10c785de55e8474c0ee35f78e8ed92b4 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -22,6 +22,9 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -30,19 +33,15 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" -using tensorflow::strings::StrCat; - namespace xla { - namespace { +using absl::StrCat; + // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. template @@ -57,7 +56,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { primitive_util::NativeToPrimitiveType()); } }); - auto result = MakeUnique(result_shape); + auto result = absl::make_unique(result_shape); // Then copy over the data from 'literal' converting FromNativeT values to // ToNativeT values as necessary. 
@@ -85,8 +84,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } // namespace /* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions) { + PrimitiveType primitive_type, absl::Span dimensions) { return Literal::CreateFromShape( ShapeUtil::MakeShape(primitive_type, dimensions)); } @@ -102,7 +100,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } /* static */ std::unique_ptr LiteralUtil::CreateToken() { - return MakeUnique(ShapeUtil::MakeTokenShape()); + return absl::make_unique(ShapeUtil::MakeTokenShape()); } /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { @@ -279,15 +277,15 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ std::unique_ptr LiteralUtil::CreateR1( const tensorflow::core::Bitmap& values) { - auto literal = MakeUnique( + auto literal = absl::make_unique( ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); literal->PopulateR1(values); return literal; } /* static */ std::unique_ptr LiteralUtil::CreateR1U8( - tensorflow::StringPiece value) { - auto literal = MakeUnique( + absl::string_view value) { + auto literal = absl::make_unique( ShapeUtil::MakeShape(U8, {static_cast(value.size())})); for (int i = 0; i < value.size(); ++i) { literal->Set({i}, value[i]); @@ -302,9 +300,8 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } /* static */ std::unique_ptr LiteralUtil::ReshapeSlice( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const LiteralSlice& literal) { + absl::Span new_dimensions, + absl::Span minor_to_major, const LiteralSlice& literal) { int64 new_num_elements = 1; for (int64 i = 0; i < new_dimensions.size(); ++i) { new_num_elements *= new_dimensions[i]; @@ -312,7 +309,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - auto 
new_literal = MakeUnique( + auto new_literal = absl::make_unique( ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); // Create a new shape with the given minor-to-major layout. This shape is used @@ -431,12 +428,13 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } /* static */ std::unique_ptr LiteralUtil::MakeTuple( - tensorflow::gtl::ArraySlice elements) { + absl::Span elements) { std::vector element_shapes; for (const auto* element : elements) { element_shapes.push_back(element->shape()); } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + auto literal = + absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); } @@ -444,12 +442,13 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } /* static */ std::unique_ptr LiteralUtil::MakeTupleFromSlices( - tensorflow::gtl::ArraySlice elements) { + absl::Span elements) { std::vector element_shapes; for (const auto& element : elements) { element_shapes.push_back(element.shape()); } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + auto literal = + absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); } @@ -463,7 +462,8 @@ std::unique_ptr ConvertType(LiteralSlice literal) { for (const auto& element : elements) { element_shapes.push_back(element->shape()); } - auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + auto literal = + absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); for (int64 i = 0; i < elements.size(); ++i) { TF_CHECK_OK( literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i})); @@ -472,8 +472,8 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } /* static */ string LiteralUtil::MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index) { - 
return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); + absl::Span multi_index) { + return StrCat("{", absl::StrJoin(multi_index, ","), "}"); } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index e3737a9d0051b32dc0becc19e1849c856a50e52e..2d6084a67a3b966d054103df0f06ddb82d0d6525 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -27,6 +27,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -34,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -43,8 +45,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -71,8 +71,7 @@ class LiteralUtil { template static std::unique_ptr CreateR0(NativeT value); template - static std::unique_ptr CreateR1( - tensorflow::gtl::ArraySlice values); + static std::unique_ptr CreateR1(absl::Span values); static std::unique_ptr CreateR1( const tensorflow::core::Bitmap& values); template @@ -141,8 +140,8 @@ class LiteralUtil { // template static std::unique_ptr CreateSparse( - tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, bool sort = true); + absl::Span dimensions, SparseIndexArray indices, + absl::Span values, bool sort = true); // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -157,7 +156,7 @@ class LiteralUtil { // Creates a literal of the given shape where each element is `value`. template static std::unique_ptr CreateFullWithDescendingLayout( - tensorflow::gtl::ArraySlice dimensions, NativeT value); + absl::Span dimensions, NativeT value); // Creates a new literal from an Array type. The variants not ending with // WithLayout use the default XLA layout for the literal's linear @@ -187,7 +186,7 @@ class LiteralUtil { const Array4D& values, const Layout& layout); // Creates a new vector of U8s literal value from a string. - static std::unique_ptr CreateR1U8(tensorflow::StringPiece value); + static std::unique_ptr CreateR1U8(absl::string_view value); // Creates a linspace-populated literal with the given number of rows and // columns. @@ -215,10 +214,10 @@ class LiteralUtil { // Returns a tuple literal composed of given literals. 
Data is copied from the // given elements into the returned literal. static std::unique_ptr MakeTuple( - tensorflow::gtl::ArraySlice elements); + absl::Span elements); static std::unique_ptr MakeTupleFromSlices( - tensorflow::gtl::ArraySlice elements); + absl::Span elements); // As above, but intended to be invoked with move semantics; i.e. // @@ -259,8 +258,7 @@ class LiteralUtil { // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions); + PrimitiveType primitive_type, absl::Span dimensions); // If the given literal's data type is bfloat16, converts it to a float // literal; otherwise, returns a copy of it. If the literal is a tuple, @@ -279,9 +277,8 @@ class LiteralUtil { // buffer of the input literal is assumed to have the given minor_to_major // layout order. static std::unique_ptr ReshapeSlice( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const LiteralSlice& literal); + absl::Span new_dimensions, + absl::Span minor_to_major, const LiteralSlice& literal); // Creates a literal with the supplied shape, and uses the provided value // generator to populate the literal's values. @@ -291,7 +288,7 @@ class LiteralUtil { typename T = typename primitive_util::PrimitiveTypeToNative::type> static StatusOr> CreateRandomLiteral( const Shape& shape, - const std::function)>& generator); + const std::function)>& generator); // Creates a literal with the supplied shape, and initializes the literal // values using a normal distribution with given mean and stddev standard @@ -319,15 +316,14 @@ class LiteralUtil { // Returns a multi-dimensional index as a string. For example: '{7, 8}' will // be returned for a 2-dimensional index with dimension 0 index equal to 7, // dimension 1 equal to 8. 
- static string MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index); + static string MultiIndexAsString(absl::Span multi_index); }; std::ostream& operator<<(std::ostream& out, const Literal& literal); template /* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { - auto literal = MakeUnique(ShapeUtil::MakeShape( + auto literal = absl::make_unique(ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {})); literal->Set({}, value); return literal; @@ -335,8 +331,8 @@ template template /* static */ std::unique_ptr LiteralUtil::CreateR1( - tensorflow::gtl::ArraySlice values) { - auto literal = MakeUnique( + absl::Span values) { + auto literal = absl::make_unique( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {static_cast(values.size())})); literal->PopulateR1(values); @@ -347,7 +343,7 @@ template /* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { - auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( + auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, @@ -427,15 +423,16 @@ template template /* static */ std::unique_ptr LiteralUtil::CreateSparse( - tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, bool sort) { + absl::Span dimensions, SparseIndexArray indices, + absl::Span values, bool sort) { int64 num_elements = values.size(); int64 rank = dimensions.size(); CHECK_EQ(num_elements, indices.index_count()); CHECK_EQ(rank, indices.rank()); - auto literal = MakeUnique(ShapeUtil::MakeShapeWithSparseLayout( - primitive_util::NativeToPrimitiveType(), dimensions, - indices.max_indices())); + auto literal = + absl::make_unique(ShapeUtil::MakeShapeWithSparseLayout( + primitive_util::NativeToPrimitiveType(), dimensions, + indices.max_indices())); literal->PopulateSparse(indices, values, 
sort); return literal; } @@ -451,7 +448,7 @@ template template /* static */ std::unique_ptr LiteralUtil::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { - auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( + auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), AsInt64Slice(layout.minor_to_major()))); literal->PopulateFromArray(values); @@ -569,10 +566,11 @@ template template /* static */ std::unique_ptr -LiteralUtil::CreateFullWithDescendingLayout( - tensorflow::gtl::ArraySlice dimensions, NativeT value) { - auto literal = MakeUnique(ShapeUtil::MakeShapeWithDescendingLayout( - primitive_util::NativeToPrimitiveType(), dimensions)); +LiteralUtil::CreateFullWithDescendingLayout(absl::Span dimensions, + NativeT value) { + auto literal = + absl::make_unique(ShapeUtil::MakeShapeWithDescendingLayout( + primitive_util::NativeToPrimitiveType(), dimensions)); literal->PopulateWithValue(value); return literal; } @@ -581,14 +579,12 @@ template /* static */ StatusOr> LiteralUtil::CreateRandomLiteral( const Shape& shape, - const std::function)>& generator) { + const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); - auto literal = MakeUnique(shape); + auto literal = absl::make_unique(shape); TF_RETURN_IF_ERROR(literal.get()->Populate( - [&](tensorflow::gtl::ArraySlice indexes) { - return generator(indexes); - })); + [&](absl::Span indexes) { return generator(indexes); })); return std::move(literal); } @@ -599,9 +595,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, using NativeT = typename primitive_util::PrimitiveTypeToNative::type; std::normal_distribution generator(mean, stddev); return CreateRandomLiteral( - shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { - return generator(*engine); - }); + shape, + [&](absl::Span /*indexes*/) { return 
generator(*engine); }); } template diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 3c74e070da529b7f1431e01fbaf31932f582db44..fcff48b6b18ba115a67f3141a9aea4ca461be55d 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -60,7 +60,7 @@ MaybeFind(const Collection& collection, if (it == collection.end()) { std::ostringstream os; os << key; - return NotFound("key not found: %s", os.str().c_str()); + return NotFound("key not found: %s", os.str()); } return {it->second}; } diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index 69ef4f7a2f3ea559a334a11cbe8392b610742bab..4eab4fa4290c270697c00be20840cf4e85459183 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -18,7 +18,8 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/strings/stringprintf.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -84,7 +85,7 @@ void MetricTableReport::WriteReportToInfoLog(double expected_metric_sum) { if (end_of_line == string::npos) { end_of_line = report.size(); } - tensorflow::StringPiece line(report.data() + pos, end_of_line - pos); + absl::string_view line(report.data() + pos, end_of_line - pos); // TODO(b/34779244): Figure out how to do this without the verbose log-line // prefix. The usual way didn't compile on open source. @@ -152,8 +153,8 @@ void MetricTableReport::AppendCategoryTable() { if (text.empty()) { text = "[no category]"; } - tensorflow::strings::StrAppend(&text, " (", category.entries.size(), " ", - entry_name_, ")"); + absl::StrAppend(&text, " (", category.entries.size(), " ", entry_name_, + ")"); AppendTableRow(text, category.metric_sum, metric_sum); // Show the top entries in the category. 
@@ -177,9 +178,9 @@ void MetricTableReport::AppendCategoryTable() { } const int64 remaining_categories = categories.size() - categories_shown; if (remaining_categories > 0) { - AppendTableRow(tensorflow::strings::StrCat("... (", remaining_categories, - " more categories)"), - expected_metric_sum_ - metric_sum, expected_metric_sum_); + AppendTableRow( + absl::StrCat("... (", remaining_categories, " more categories)"), + expected_metric_sum_ - metric_sum, expected_metric_sum_); } } @@ -206,9 +207,9 @@ void MetricTableReport::AppendEntryTable() { } const int64 remaining_entries = entries_.size() - entries_shown; if (remaining_entries > 0) { - AppendTableRow(tensorflow::strings::StrCat("... (", remaining_entries, - " more ", entry_name_, ")"), - expected_metric_sum_ - metric_sum, expected_metric_sum_); + AppendTableRow( + absl::StrCat("... (", remaining_entries, " more ", entry_name_, ")"), + expected_metric_sum_ - metric_sum, expected_metric_sum_); } } @@ -241,10 +242,10 @@ double MetricTableReport::UnaccountedMetric() { string MetricTableReport::MetricString(double metric) { // Round to integer and stringify. - string s1 = tensorflow::strings::StrCat(std::llround(metric)); + string s1 = absl::StrCat(std::llround(metric)); // Code below commafies the string, e.g. "1234" becomes "1,234". - tensorflow::StringPiece sp1(s1); + absl::string_view sp1(s1); string output; // Copy leading non-digit characters unconditionally. // This picks up the leading sign. 
@@ -263,8 +264,7 @@ string MetricTableReport::MetricString(double metric) { } string MetricTableReport::MetricPercent(double metric) { - return tensorflow::strings::Printf("%5.2f%%", - metric / expected_metric_sum_ * 100.0); + return absl::StrFormat("%5.2f%%", metric / expected_metric_sum_ * 100.0); } } // namespace xla diff --git a/tensorflow/compiler/xla/metric_table_report.h b/tensorflow/compiler/xla/metric_table_report.h index 818fb1d3fe0b8bbe1a8eba363ff6445e2f3df9d2..062d8ed99b213535ad39d840aaaf10a6fe0da84c 100644 --- a/tensorflow/compiler/xla/metric_table_report.h +++ b/tensorflow/compiler/xla/metric_table_report.h @@ -18,9 +18,8 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -108,7 +107,7 @@ class MetricTableReport { // Append all parameters to the report. template void AppendLine(Args... args) { - tensorflow::strings::StrAppend(&report_, std::forward(args)..., "\n"); + absl::StrAppend(&report_, std::forward(args)..., "\n"); } // Represents a set of entries with the same category_text. diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 6b7fd10d63f8f97b0e0bf7570488c06323368d75..f9473d372bb15058d7413e2ac8a303dd34322180 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -19,15 +19,15 @@ limitations under the License. 
#include #include +#include "absl/base/casts.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -54,17 +54,17 @@ StatusOr> PackedLiteralReader::Read( if (shape.element_type() != F32) { return Unimplemented( "not yet implemented element type for packed literal reading: %s", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } - auto result = MakeUnique(literal_shape); + auto result = absl::make_unique(literal_shape); result->PopulateWithValue(std::numeric_limits::quiet_NaN()); int64 elements = ShapeUtil::ElementsIn(shape); - tensorflow::gtl::ArraySlice field = result->data(); - char* data = tensorflow::bit_cast(field.data()); + absl::Span field = result->data(); + char* data = absl::bit_cast(field.data()); uint64 bytes = elements * sizeof(float); - tensorflow::StringPiece sp; + absl::string_view sp; auto s = file_->Read(offset_, bytes, &sp, data); offset_ += sp.size(); if (!s.ok()) { @@ -85,7 +85,7 @@ bool PackedLiteralReader::IsExhausted() const { // Try to read a single byte from offset_. If we can't, we've // exhausted the data. 
char single_byte[1]; - tensorflow::StringPiece sp; + absl::string_view sp; auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte); return !s.ok(); } diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/ptr_util.h deleted file mode 100644 index bfcdfc62f9541ab09b94a48d5121e16bad4d43cd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/ptr_util.h +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ - -// As this was moved to tensorflow/core/util, provide indirections here to -// maintain current functionality of the library. 
- -#include - -#include -#include -#include - -#include "tensorflow/core/util/ptr_util.h" - -namespace xla { -using tensorflow::MakeUnique; -using tensorflow::WrapUnique; -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index c8f2d65c223ccfe20862954c224d016cca421812..f0d84646b9f01ad3ad209073f13b7b3ec21635d1 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -39,6 +39,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/python:numpy_lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -59,6 +62,8 @@ cc_library( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 8246f76d3443d58f4174cc4f86100f54d6b46928..cd6e20b69366c064e20c6e0a7d1aebe6229690d8 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -14,10 +14,10 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/python/local_computation_builder.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -137,8 +137,7 @@ static StatusOr ToBuffer(LocalClient* client, /* static */ StatusOr LocalShapedBuffer::FromLiteral( - const Literal& argument, - const tensorflow::gtl::optional& shape_with_layout) { + const Literal& argument, const absl::optional& shape_with_layout) { LocalClient* client = GetOrCreateLocalClient(); StatusOr buf = [&] { if (shape_with_layout) { @@ -163,7 +162,7 @@ CompiledLocalComputation::CompiledLocalComputation( StatusOr> CompiledLocalComputation::Execute( const std::vector& arguments, - const std::vector>& shapes_with_layout) { + const std::vector>& shapes_with_layout) { LocalClient* client = GetOrCreateLocalClient(); VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas."; @@ -194,7 +193,7 @@ StatusOr> CompiledLocalComputation::Execute( scoped_buffers.reserve(arguments.size()); for (int i = 0; i < arguments.size(); ++i) { const Literal& argument = arguments[i]; - const tensorflow::gtl::optional& shape_with_layout = + const absl::optional& shape_with_layout = shapes_with_layout[i]; StatusOr pushed; @@ -252,7 +251,7 @@ StatusOr> CompiledLocalComputation::Execute( return InternalError( "Failed running replica %d (other replicas may have failed as well): " "%s.", - replica, statusor.status().ToString().c_str()); + replica, statusor.status().ToString()); } } @@ -260,7 +259,7 @@ StatusOr> CompiledLocalComputation::Execute( } LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( - tensorflow::gtl::ArraySlice 
argument_handles) { + absl::Span argument_handles) { LocalClient* client = GetOrCreateLocalClient(); std::vector argument_buffers; @@ -370,8 +369,7 @@ LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { } LocalOp LocalComputationBuilder::Broadcast( - const LocalOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { + const LocalOp& operand, absl::Span broadcast_sizes) { return xla::Broadcast(operand.op(), broadcast_sizes); } @@ -381,14 +379,14 @@ LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, return xla::Pad(operand.op(), padding_value.op(), padding_config); } -LocalOp LocalComputationBuilder::Reshape( - const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { +LocalOp LocalComputationBuilder::Reshape(const LocalOp& operand, + absl::Span dimensions, + absl::Span new_sizes) { return xla::Reshape(operand.op(), dimensions, new_sizes); } -LocalOp LocalComputationBuilder::Collapse( - const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { +LocalOp LocalComputationBuilder::Collapse(const LocalOp& operand, + absl::Span dimensions) { return xla::Collapse(operand.op(), dimensions); } @@ -396,10 +394,10 @@ LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { return xla::CrossReplicaSum(operand.op()); } -LocalOp LocalComputationBuilder::Slice( - const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { +LocalOp LocalComputationBuilder::Slice(const LocalOp& operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { return xla::Slice(operand.op(), start_indices, limit_indices, strides); } @@ -412,7 +410,7 @@ LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, LocalOp LocalComputationBuilder::DynamicSlice( const LocalOp& operand, const LocalOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span 
slice_sizes) { return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } @@ -422,8 +420,8 @@ LocalOp LocalComputationBuilder::DynamicUpdateSlice( return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op()); } -LocalOp LocalComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, int64 dimension) { +LocalOp LocalComputationBuilder::ConcatInDim(absl::Span operands, + int64 dimension) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -434,18 +432,16 @@ LocalOp LocalComputationBuilder::ConcatInDim( LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( const LocalOp& operand, const LocalComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const LocalOp& source, const LocalOp& init_value, - const LocalComputation& scatter) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, const LocalOp& source, + const LocalOp& init_value, const LocalComputation& scatter) { return xla::SelectAndScatterWithGeneralPadding( operand.op(), select.computation(), window_dimensions, window_strides, padding, source.op(), init_value.op(), scatter.computation()); } -LocalOp LocalComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { +LocalOp LocalComputationBuilder::Tuple(absl::Span elements) { std::vector xla_ops; xla_ops.reserve(elements.size()); for (const auto& op : elements) { @@ -472,10 +468,9 @@ LocalOp LocalComputationBuilder::DotGeneral( LocalOp LocalComputationBuilder::ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, absl::Span rhs_dilation, const 
ConvolutionDimensionNumbers& dimension_numbers) { return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding, lhs_dilation, rhs_dilation, dimension_numbers); @@ -491,9 +486,8 @@ LocalOp LocalComputationBuilder::BitcastConvertType( return xla::BitcastConvertType(operand.op(), new_element_type); } -LocalOp LocalComputationBuilder::Call( - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands) { +LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation, + absl::Span operands) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -503,19 +497,18 @@ LocalOp LocalComputationBuilder::Call( } LocalOp LocalComputationBuilder::Transpose( - const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { + const LocalOp& operand, absl::Span permutation) { return xla::Transpose(operand.op(), permutation); } -LocalOp LocalComputationBuilder::Rev( - const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { +LocalOp LocalComputationBuilder::Rev(const LocalOp& operand, + absl::Span dimensions) { return xla::Rev(operand.op(), dimensions); } -LocalOp LocalComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions) { +LocalOp LocalComputationBuilder::Map(absl::Span operands, + const LocalComputation& local_computation, + absl::Span dimensions) { std::vector xla_ops; xla_ops.reserve(operands.size()); for (const auto& op : operands) { @@ -529,7 +522,7 @@ LocalOp LocalComputationBuilder::Map( LocalOp LocalComputationBuilder::Reduce( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + absl::Span dimensions_to_reduce) { return xla::Reduce(operand.op(), init_value.op(), local_computation.computation(), dimensions_to_reduce); } @@ -537,9 +530,9 @@ LocalOp LocalComputationBuilder::Reduce( 
LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding) { return xla::ReduceWindowWithGeneralPadding( operand.op(), init_value.op(), local_computation.computation(), window_dimensions, window_strides, padding); @@ -575,6 +568,16 @@ StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { return builder_.IsConstant(operand.op()); } +LocalOp LocalComputationBuilder::Sort(const LocalOp& operand, int64 dimension) { + return xla::Sort(operand.op(), absl::nullopt, dimension); +} + +LocalOp LocalComputationBuilder::SortKeyVal(const LocalOp& keys, + const LocalOp& values, + int64 dimension) { + return xla::Sort(keys.op(), values.op(), dimension); +} + StatusOr LocalComputationBuilder::BuildConstantSubGraph( const LocalOp& operand) { TF_ASSIGN_OR_RETURN(XlaComputation computation, @@ -590,10 +593,10 @@ StatusOr LocalComputationBuilder::BuildConstantSubGraph( #define _FORWARD_UNOP(method_name) \ _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) -#define _FORWARD_BINOP(method_name) \ - _FORWARD(method_name, LocalOp, \ - (const LocalOp& lhs, const LocalOp& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions), \ +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + absl::Span broadcast_dimensions), \ (lhs.op(), rhs.op(), broadcast_dimensions)) #define _FORWARD_TRIOP(method_name) \ @@ -640,7 +643,6 @@ _FORWARD_UNOP(Sin) _FORWARD_UNOP(Tanh) _FORWARD_UNOP(IsFinite) _FORWARD_UNOP(Neg) -_FORWARD_UNOP(Sort) _FORWARD_UNOP(Sqrt) _FORWARD_UNOP(Rsqrt) _FORWARD_UNOP(Square) @@ -688,8 +690,7 @@ StatusOr DestructureLocalShapedBufferTuple( "Attemped to destructure a 
LocalShapedBuffer that did not have a tuple " "shape; shape: %s", ShapeUtil::HumanString( - local_shaped_buffer->shaped_buffer()->on_device_shape()) - .c_str()); + local_shaped_buffer->shaped_buffer()->on_device_shape())); } DeviceMemoryAllocator* allocator = diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index a568c24c6376e1fe17f5e5a4f6626bf0970985a3..78b3c598b97294d2ba4deb72ec9c1251ef68b7cf 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace swig { @@ -60,8 +60,7 @@ StatusOr > TransferFromOutfeedLocalReplica( class LocalShapedBuffer { public: static StatusOr FromLiteral( - const Literal& argument, - const tensorflow::gtl::optional& shape_with_layout); + const Literal& argument, const absl::optional& shape_with_layout); LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); const ScopedShapedBuffer* shaped_buffer() const; @@ -120,10 +119,10 @@ class CompiledLocalComputation { // shapes_with_layout. 
StatusOr > Execute( const std::vector& arguments, - const std::vector >& shapes_with_layout); + const std::vector >& shapes_with_layout); LocalShapedBuffer* ExecuteWithShapedBuffers( - tensorflow::gtl::ArraySlice argument_handles); + absl::Span argument_handles); private: std::unique_ptr executable_; @@ -200,46 +199,41 @@ class LocalComputationBuilder { LocalOp ConstantLiteral(const Literal& literal); LocalOp Broadcast(const LocalOp& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + absl::Span broadcast_sizes); LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, const PaddingConfig& padding_config); - LocalOp Reshape(const LocalOp& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + LocalOp Reshape(const LocalOp& operand, absl::Span dimensions, + absl::Span new_sizes); - LocalOp Collapse(const LocalOp& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Collapse(const LocalOp& operand, absl::Span dimensions); LocalOp CrossReplicaSum(const LocalOp& operand); - LocalOp Slice(const LocalOp& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + LocalOp Slice(const LocalOp& operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); LocalOp SliceInDim(const LocalOp& operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno); LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, const LocalOp& start_indices); - LocalOp ConcatInDim(tensorflow::gtl::ArraySlice operands, - int64 dimension); + LocalOp ConcatInDim(absl::Span operands, int64 dimension); LocalOp SelectAndScatterWithGeneralPadding( const LocalOp& operand, const LocalComputation& select, - tensorflow::gtl::ArraySlice window_dimensions, - 
tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice > padding, - const LocalOp& source, const LocalOp& init_value, - const LocalComputation& scatter); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span > padding, const LocalOp& source, + const LocalOp& init_value, const LocalComputation& scatter); - LocalOp Tuple(tensorflow::gtl::ArraySlice elements); + LocalOp Tuple(absl::Span elements); LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); @@ -250,10 +244,10 @@ class LocalComputationBuilder { LocalOp ConvGeneralDilated( const LocalOp& lhs, const LocalOp& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice > padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, + absl::Span window_strides, + absl::Span > padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); LocalOp ConvertElementType(const LocalOp& operand, @@ -263,28 +257,27 @@ class LocalComputationBuilder { PrimitiveType new_element_type); LocalOp Call(const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); LocalOp Transpose(const LocalOp& operand, - tensorflow::gtl::ArraySlice permutation); + absl::Span permutation); - LocalOp Rev(const LocalOp& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Rev(const LocalOp& operand, absl::Span dimensions); - LocalOp Map(tensorflow::gtl::ArraySlice operands, + LocalOp Map(absl::Span operands, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + absl::Span dimensions_to_reduce); LocalOp ReduceWindowWithGeneralPadding( const LocalOp& operand, const LocalOp& init_value, const LocalComputation& 
local_computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice > padding); + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span > padding); LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, const Shape& shape); @@ -301,6 +294,11 @@ class LocalComputationBuilder { StatusOr IsConstant(const LocalOp& operand); + LocalOp Sort(const LocalOp& operand, int64 dimension); + + LocalOp SortKeyVal(const LocalOp& keys, const LocalOp& values, + int64 dimension); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ @@ -312,7 +310,7 @@ class LocalComputationBuilder { #define _FORWARD_BINOP(method_name) \ _FORWARD(method_name, LocalOp, \ (const LocalOp& lhs, const LocalOp& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions)) + absl::Span broadcast_dimensions)) #define _FORWARD_TRIOP(method_name) \ _FORWARD(method_name, LocalOp, \ @@ -357,7 +355,6 @@ class LocalComputationBuilder { _FORWARD_UNOP(Tanh) _FORWARD_UNOP(IsFinite) _FORWARD_UNOP(Neg) - _FORWARD_UNOP(Sort) _FORWARD_UNOP(Sqrt) _FORWARD_UNOP(Rsqrt) _FORWARD_UNOP(Square) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 5d5a955bfee35b38a61b9a9f792c1b31259ce044..76c09512d82006af35e2508ce8e60f23a4c056c3 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,15 +22,15 @@ limitations under the License. 
// // C++ Python // -------------------------------------+--------------------------------------- -// ArraySlice <- sequence of int -// ArraySlice <- sequence of LocalOp +// Span <- sequence of int +// Span <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) // <- object duck-typed as xla_client.Shape // std::vector <- sequence of xla_client.Shape objects // PrimitiveType <- int -// ArraySlice> <- sequence of int pairs +// Span> <- sequence of int pairs // PaddingConfig proto <- corresponding Python proto // ConvolutionDimensionNumbers proto <- corresponding Python proto // DotDimensionNumbers proto <- corresponding Python proto @@ -109,10 +109,12 @@ limitations under the License. // Must be included first #include "tensorflow/python/lib/core/numpy.h" +#include "third_party/absl/strings/str_cat.h" +#include "third_party/absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" +#include "third_party/absl/types/span.h" #include "tensorflow/compiler/xla/python/numpy_bridge.h" #include "tensorflow/compiler/xla/python/local_computation_builder.h" @@ -154,8 +156,8 @@ bool HandleStringAttribute(PyObject* o, return true; // The attribute is None, which we consider ok. } if (!PyString_Check(attr)) { - string message = tensorflow::strings::Printf("%s must be a string or none; got %s", - attr_name, numpy::PyObjectCppRepr(attr).c_str()); + string message = absl::StrFormat("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); PyErr_SetString(PyExc_TypeError, message.c_str()); Py_DECREF(attr); return false; // Type error, not ok. 
@@ -265,9 +267,9 @@ tensorflow::ImportNumpy(); $result = Py_None; } -// ArraySlice +// Span -%typemap(in) tensorflow::gtl::ArraySlice +%typemap(in) absl::Span (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -297,9 +299,9 @@ tensorflow::ImportNumpy(); $1 = temps; } -// ArraySlice +// Span -%typemap(in) tensorflow::gtl::ArraySlice( +%typemap(in) absl::Span( std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -321,7 +323,7 @@ tensorflow::ImportNumpy(); // LocalShapedBuffer* -%typemap(in) tensorflow::gtl::ArraySlice +%typemap(in) absl::Span (std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -409,10 +411,10 @@ tensorflow::ImportNumpy(); $1 = &temp; } -%typemap(in) const tensorflow::gtl::optional& ( - tensorflow::gtl::optional temp) { +%typemap(in) const absl::optional& ( + absl::optional temp) { if ($input == Py_None) { - temp = tensorflow::gtl::nullopt; + temp = absl::nullopt; $1 = &temp; } else { StatusOr statusor = numpy::XlaShapeFromPyShape($input); @@ -448,8 +450,8 @@ tensorflow::ImportNumpy(); $1 = &temps; } -%typemap(in) const std::vector >& ( - std::vector > temps) { +%typemap(in) const std::vector >& ( + std::vector > temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); SWIG_fail; @@ -458,7 +460,7 @@ tensorflow::ImportNumpy(); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); if (o == Py_None) { - temps.push_back(tensorflow::gtl::nullopt); + temps.push_back(absl::nullopt); } else { StatusOr statusor = numpy::XlaShapeFromPyShape(o); Py_DECREF(o); @@ -494,9 +496,9 @@ tensorflow::ImportNumpy(); $1 = static_cast(value); } -// ArraySlice> +// Span> -%typemap(in) tensorflow::gtl::ArraySlice > +%typemap(in) absl::Span > (std::vector > temps) { if (!PySequence_Check($input)) { 
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); @@ -896,7 +898,7 @@ tensorflow::ImportNumpy(); if (o != Py_None) { StatusOr statusor = numpy::XlaShapeFromPyShape(o); if (!statusor.ok()) { - PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); + PyErr_SetString(PyExc_TypeError, absl::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); Py_DECREF(o); SWIG_fail; } @@ -1011,6 +1013,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Pow; %unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::LocalComputationBuilder::SortKeyVal; %unignore xla::swig::LocalComputationBuilder::Sqrt; %unignore xla::swig::LocalComputationBuilder::Rsqrt; %unignore xla::swig::LocalComputationBuilder::Square; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 6f665faf61b25b23a32ce4d0a012543ba18d7e64..fc6511bef566cb6f4e0d4e52972954de0792e959 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" @@ -149,9 +151,7 @@ static int NumpyTypenum(PyObject* o) { // // NOTE: this is an internal helper for conversion to a C++, and so decrefs r. 
static string ExtractStringAndDecref(PyObject* r) { - auto error = [r] { - return tensorflow::strings::Printf("", r); - }; + auto error = [r] { return absl::StrFormat("", r); }; if (r == nullptr) { return error(); } @@ -191,8 +191,8 @@ StatusOr XlaShapeFromPyShape(PyObject* o) { PyObject* result = PyObject_CallMethod(o, const_cast(method.c_str()), nullptr); if (result == nullptr) { - return error(tensorflow::strings::StrCat( - "Failed to call method of shape object:", method)); + return error( + absl::StrCat("Failed to call method of shape object:", method)); } return result; }; @@ -281,15 +281,15 @@ StatusOr XlaShapeFromPyShape(PyObject* o) { // Helper that retrieves the member with attr_name, stringifies it if is not // None, and returns it as a C++ string. -static tensorflow::gtl::optional GetAttrAsString( - PyObject* o, const string& attr_name) { +static absl::optional GetAttrAsString(PyObject* o, + const string& attr_name) { if (!PyObject_HasAttrString(o, attr_name.c_str())) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str()); if (attr == Py_None) { Py_DECREF(attr); - return tensorflow::gtl::nullopt; + return absl::nullopt; } string result = PyObjectCppStr(attr); Py_DECREF(attr); @@ -298,48 +298,46 @@ static tensorflow::gtl::optional GetAttrAsString( // Helper that retrieves the member with attr_name, checks that it is an integer // if it is not None, and returns it as an int32 value. 
-static tensorflow::gtl::optional GetAttrAsInt32( - PyObject* o, const string& attr_name) { +static absl::optional GetAttrAsInt32(PyObject* o, + const string& attr_name) { if (!PyObject_HasAttrString(o, attr_name.c_str())) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str()); if (attr == Py_None) { Py_DECREF(attr); - return tensorflow::gtl::nullopt; + return absl::nullopt; } if (!CheckPyIntOrLong(attr)) { Py_DECREF(attr); - return tensorflow::gtl::nullopt; + return absl::nullopt; } long value = PyIntOrPyLongToLong(attr); // NOLINT Py_DECREF(attr); if (value == -1 && PyErr_Occurred() != nullptr) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } if (static_cast(value) != value) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } return value; } StatusOr OpMetadataFromPyObject(PyObject* o) { OpMetadata result; - tensorflow::gtl::optional op_type = GetAttrAsString(o, "op_type"); + absl::optional op_type = GetAttrAsString(o, "op_type"); if (op_type.has_value()) { result.set_op_type(op_type.value()); } - tensorflow::gtl::optional op_name = GetAttrAsString(o, "op_name"); + absl::optional op_name = GetAttrAsString(o, "op_name"); if (op_name.has_value()) { result.set_op_name(op_name.value()); } - tensorflow::gtl::optional source_file = - GetAttrAsString(o, "source_file"); + absl::optional source_file = GetAttrAsString(o, "source_file"); if (source_file.has_value()) { result.set_source_file(source_file.value()); } - tensorflow::gtl::optional source_line = - GetAttrAsInt32(o, "source_line"); + absl::optional source_line = GetAttrAsInt32(o, "source_line"); if (source_line.has_value()) { result.set_source_line(source_line.value()); } diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index a67c93a4fb7413f9bbcb9afd92c36fd118836e1f..8cae1751853f3cd18033ecf6edca40bf99c6d917 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h 
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -25,9 +25,9 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/python/lib/core/numpy.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index a2c6fc344d192265d536ef7e23ad5c6d7c847014..fa4366ff0789a3d05c26479a746a18dfcf7e902b 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -105,7 +105,6 @@ _UNARY_OPS = [ 'Square', 'Reciprocal', 'Neg', - 'Sort', 'Erf', 'Erfc', 'ErfInv', @@ -1218,6 +1217,14 @@ class ComputationBuilder(object): lhs_dilation, rhs_dilation, dimension_numbers) + def Sort(self, operand, dimension=-1): + """Enqueues a sort operation onto the computation.""" + return self._client.Sort(operand, dimension) + + def SortKeyVal(self, keys, values, dimension=-1): + """Enqueues a key-value sort operation onto the computation.""" + return self._client.SortKeyVal(keys, values, dimension) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index a803520876952a0ab67ecb827b1f256c915335f9..a4854f593f0a579e3461b35033620e762593c6a6 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -43,7 +44,7 @@ std::unique_ptr> MatmulArray2DImpl( int m = lhs.height(); int n = rhs.width(); int k = lhs.width(); - auto result = MakeUnique>(m, n); + auto result = absl::make_unique>(m, n); // Because Eigen is a header-oriented library, make sure that the Eigen code // is the same as the code used by the CPU backend (otherwise the linker will // randomly pick *some* definition). @@ -77,7 +78,8 @@ std::unique_ptr> MatmulArray2DImpl( /* static */ std::unique_ptr> ReferenceUtil::Array2DF32ToF64( const Array2D& input) { - auto result = MakeUnique>(input.height(), input.width()); + auto result = + absl::make_unique>(input.height(), input.width()); for (int64 rowno = 0; rowno < input.height(); ++rowno) { for (int64 colno = 0; colno < input.height(); ++colno) { (*result)(rowno, colno) = input(rowno, colno); @@ -106,17 +108,15 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( // array by adding a fourth dummy dimension of size 1 without stride, padding // and dilation. 
Array4D a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1); - a4dlhs.Each( - [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { - CHECK_EQ(indices[3], 0); - *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]); - }); + a4dlhs.Each([&](absl::Span indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]); + }); Array4D a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1); - a4drhs.Each( - [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { - CHECK_EQ(indices[3], 0); - *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]); - }); + a4drhs.Each([&](absl::Span indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]); + }); // Add a second dummy spatial dimensions. ConvolutionDimensionNumbers dnums2d = dnums; dnums2d.add_input_spatial_dimensions(3); @@ -126,13 +126,12 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1}, {rhs_dilation, 1}, dnums2d); - auto convr3 = MakeUnique>(convr4->planes(), convr4->depth(), - convr4->height()); - convr4->Each( - [&](tensorflow::gtl::ArraySlice indices, float* value_ptr) { - CHECK_EQ(indices[3], 0); - convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr; - }); + auto convr3 = absl::make_unique>( + convr4->planes(), convr4->depth(), convr4->height()); + convr4->Each([&](absl::Span indices, float* value_ptr) { + CHECK_EQ(indices[3], 0); + convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr; + }); return convr3; } @@ -187,11 +186,11 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, + const absl::Span& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const 
tensorflow::gtl::ArraySlice>& padding) { + const absl::Span& window, + const absl::Span& stride, + const absl::Span>& padding) { std::vector dim_lengths{static_cast(operand.size())}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); @@ -201,7 +200,7 @@ ReferenceUtil::ReduceWindow1DGeneric( window_util::StridedBound(padded_width, window[i], stride[i]); pad_low[i] = padding[i].first; } - auto result = MakeUnique>(window_counts[0]); + auto result = absl::make_unique>(window_counts[0]); // Do a full 1D reduce window. for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { @@ -219,10 +218,11 @@ ReferenceUtil::ReduceWindow1DGeneric( } /* static */ std::unique_ptr> -ReferenceUtil::ReduceWindow1DAdd( - const tensorflow::gtl::ArraySlice& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { +ReferenceUtil::ReduceWindow1DAdd(const absl::Span& operand, + float init, + const absl::Span& window, + const absl::Span& stride, + Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{static_cast(operand.size())}; return ReduceWindow1DGeneric( @@ -234,9 +234,9 @@ ReferenceUtil::ReduceWindow1DAdd( ReferenceUtil::ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding) { + const absl::Span& window, + const absl::Span& stride, + const absl::Span>& padding) { std::vector dim_lengths{operand.height(), operand.width()}; std::vector window_counts(window.size(), 0); @@ -247,7 +247,8 @@ ReferenceUtil::ReduceWindow2DGeneric( window_util::StridedBound(padded_width, window[i], stride[i]); pad_low[i] = padding[i].first; } - auto result = MakeUnique>(window_counts[0], window_counts[1]); + auto result = + absl::make_unique>(window_counts[0], window_counts[1]); // 
Do a full 2D reduce window. for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { @@ -273,8 +274,8 @@ ReferenceUtil::ReduceWindow2DGeneric( /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( const Array2D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const absl::Span& window, + const absl::Span& stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{operand.height(), operand.width()}; return ReduceWindow2DGeneric( @@ -284,8 +285,8 @@ ReferenceUtil::ReduceWindow2DGeneric( /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( const Array3D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const absl::Span& window, + const absl::Span& stride, Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); @@ -296,8 +297,8 @@ ReferenceUtil::ReduceWindow2DGeneric( WindowCount(dim_lengths[i], window[i], stride[i], padding); pad_low[i] = padding_both[i].first; } - auto result = MakeUnique>(window_counts[0], window_counts[1], - window_counts[2]); + auto result = absl::make_unique>( + window_counts[0], window_counts[1], window_counts[2]); for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { @@ -331,8 +332,8 @@ ReferenceUtil::ReduceWindow2DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const absl::Span& window, + const absl::Span& stride, Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; return ReduceWindow4DGeneric( @@ -344,9 +345,9 @@ 
ReferenceUtil::ReduceWindow4DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding) { + const absl::Span& window, + const absl::Span& stride, + const absl::Span>& padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; @@ -358,8 +359,8 @@ ReferenceUtil::ReduceWindow4DGeneric( window_util::StridedBound(padded_width, window[i], stride[i]); pad_low[i] = padding[i].first; } - auto result = MakeUnique>(window_counts[0], window_counts[1], - window_counts[2], window_counts[3]); + auto result = absl::make_unique>( + window_counts[0], window_counts[1], window_counts[2], window_counts[3]); // Do a full 4D reduce window. for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { @@ -399,8 +400,8 @@ ReferenceUtil::ReduceWindow4DGeneric( /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow4DAdd( const Array4D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding) { + const absl::Span& window, + const absl::Span& stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, padding); @@ -421,13 +422,15 @@ ReferenceUtil::ReduceWindow4DGeneric( } /* static */ std::unique_ptr> -ReferenceUtil::SelectAndScatter4DGePlus( - const Array4D& operand, const Array4D& source, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, bool same_padding) { +ReferenceUtil::SelectAndScatter4DGePlus(const Array4D& operand, + const Array4D& source, + float init, + const absl::Span& window, + const absl::Span& stride, + bool same_padding) { Padding padding = same_padding ? 
Padding::kSame : Padding::kValid; - auto result = MakeUnique>(operand.n1(), operand.n2(), - operand.n3(), operand.n4()); + auto result = absl::make_unique>(operand.n1(), operand.n2(), + operand.n3(), operand.n4()); std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); @@ -583,12 +586,12 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); auto result = - MakeUnique>(result_literal->shape().dimensions(0), - result_literal->shape().dimensions(1), - result_literal->shape().dimensions(2), - result_literal->shape().dimensions(3)); + absl::make_unique>(result_literal->shape().dimensions(0), + result_literal->shape().dimensions(1), + result_literal->shape().dimensions(2), + result_literal->shape().dimensions(3)); - result->Each([&](tensorflow::gtl::ArraySlice indices, float* value) { + result->Each([&](absl::Span indices, float* value) { *value = result_literal->Get(indices); }); @@ -601,7 +604,7 @@ ReferenceUtil::ReduceToColArray2D( const std::function& reduce_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(); + auto result = absl::make_unique>(); for (int64 i = 0; i < rows; ++i) { float acc = init; for (int64 j = 0; j < cols; ++j) { @@ -618,7 +621,7 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& reduce_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(); + auto result = absl::make_unique>(); for (int64 i = 0; i < cols; ++i) { float acc = init; for (int64 j = 0; j < rows; ++j) { @@ -630,8 +633,7 @@ ReferenceUtil::ReduceToRowArray2D( } /*static*/ std::vector ReferenceUtil::Reduce4DTo1D( - const Array4D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array4D& array, float init, absl::Span dims, const std::function& reduce_function) { std::vector result; CHECK_EQ(dims.size(), 3); @@ -674,8 
+676,8 @@ ReferenceUtil::ReduceToRowArray2D( /* static */ std::unique_ptr> ReferenceUtil::Broadcast1DTo4D( const std::vector& array, const std::vector& bounds, int64 broadcast_from_dim) { - auto result = - MakeUnique>(bounds[0], bounds[1], bounds[2], bounds[3]); + auto result = absl::make_unique>(bounds[0], bounds[1], + bounds[2], bounds[3]); for (int64 i = 0; i < result->n1(); ++i) { for (int64 j = 0; j < result->n2(); ++j) { for (int64 k = 0; k < result->n3(); ++k) { @@ -704,13 +706,12 @@ ReferenceUtil::ReduceToRowArray2D( } /* static */ std::unique_ptr> ReferenceUtil::Reduce3DTo2D( - const Array3D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array3D& array, float init, absl::Span dims, const std::function& reduce_function) { CHECK_EQ(dims.size(), 1); int64 rows = dims[0] == 0 ? array.n2() : array.n1(); int64 cols = dims[0] == 2 ? array.n2() : array.n3(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); result->Fill(init); for (int i0 = 0; i0 < array.n1(); ++i0) { for (int i1 = 0; i1 < array.n2(); ++i1) { @@ -730,7 +731,7 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& map_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); for (int64 i = 0; i < rows; ++i) { for (int64 j = 0; j < cols; ++j) { (*result)(i, j) = map_function(matrix(i, j)); @@ -746,7 +747,7 @@ ReferenceUtil::ReduceToRowArray2D( CHECK_EQ(lhs.width(), rhs.width()); int64 rows = lhs.height(); int64 cols = rhs.width(); - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); for (int64 i = 0; i < rows; ++i) { for (int64 j = 0; j < cols; ++j) { (*result)(i, j) = map_function(lhs(i, j), rhs(i, j)); @@ -760,7 +761,7 @@ ReferenceUtil::ReduceToRowArray2D( const std::function& map_function) { int64 rows = matrix.height(); int64 cols = matrix.width(); - auto result = MakeUnique>(rows, cols); + auto 
result = absl::make_unique>(rows, cols); for (int64 i = 0; i < rows; ++i) { for (int64 j = 0; j < cols; ++j) { (*result)(i, j) = map_function(matrix(i, j), i, j); diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 8fa6961d197dce519cf151283b8bc0836a4615c0..9ce098029dbc35f6b4bab2efd77bee2b7e1a6255 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -22,14 +22,14 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -42,7 +42,8 @@ class ReferenceUtil { template static std::unique_ptr> TransposeArray2D( const Array2D& operand) { - auto result = MakeUnique>(operand.width(), operand.height()); + auto result = + absl::make_unique>(operand.width(), operand.height()); for (int64 w = 0; w < operand.width(); ++w) { for (int64 h = 0; h < operand.height(); ++h) { (*result)(w, h) = operand(h, w); @@ -143,8 +144,7 @@ class ReferenceUtil { // Returns the result of reducing the 4D array to a vector, reducing away // the dimensions specified in dims. static std::vector Reduce4DTo1D( - const Array4D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array4D& array, float init, absl::Span dims, const std::function& reduce_function); // Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`. 
@@ -155,8 +155,7 @@ class ReferenceUtil { // Returns the result of reducing the 3D array to a 2D array, reducing away // the dimensions specified in dims. static std::unique_ptr> Reduce3DTo2D( - const Array3D& array, float init, - tensorflow::gtl::ArraySlice dims, + const Array3D& array, float init, absl::Span dims, const std::function& reduce_function); // Applies map_function to each element in the input (2D array) and returns @@ -178,47 +177,47 @@ class ReferenceUtil { // Windowed reductions with Add as the function to apply. static std::unique_ptr> ReduceWindow1DAdd( - const tensorflow::gtl::ArraySlice& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const absl::Span& operand, float init, + const absl::Span& window, + const absl::Span& stride, Padding padding); static std::unique_ptr> ReduceWindow2DAdd( const Array2D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const absl::Span& window, + const absl::Span& stride, Padding padding); static std::unique_ptr> ReduceWindow3DAdd( const Array3D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const absl::Span& window, + const absl::Span& stride, Padding padding); static std::unique_ptr> ReduceWindow4DAdd( const Array4D& operand, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const absl::Span& window, + const absl::Span& stride, Padding padding); // Windowed reductions with a generic reduce function. 
static std::unique_ptr> ReduceWindow1DGeneric( - const tensorflow::gtl::ArraySlice& operand, float init, + const absl::Span& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding); + const absl::Span& window, + const absl::Span& stride, + const absl::Span>& padding); static std::unique_ptr> ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding); + const absl::Span& window, + const absl::Span& stride, + const absl::Span>& padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, Padding padding); + const absl::Span& window, + const absl::Span& stride, Padding padding); // With arbitrary padding. static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, - const tensorflow::gtl::ArraySlice>& padding); + const absl::Span& window, + const absl::Span& stride, + const absl::Span>& padding); // Batch normalize data. static std::unique_ptr> BatchNorm4D( @@ -231,8 +230,8 @@ class ReferenceUtil { // TODO(b/74533103) Switch tests to evaluator and remove this implementation. static std::unique_ptr> SelectAndScatter4DGePlus( const Array4D& operand, const Array4D& source, float init, - const tensorflow::gtl::ArraySlice& window, - const tensorflow::gtl::ArraySlice& stride, bool same_padding); + const absl::Span& window, + const absl::Span& stride, bool same_padding); // Concatenates the lhs and rhs arrays along the concatenate_dimension. // E.g. 
if concatenate_dimension is 0, the "n1"/height dimension is @@ -242,7 +241,7 @@ class ReferenceUtil { const Array2D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 2); - auto result = MakeUnique>( + auto result = absl::make_unique>( concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(), concatenate_dimension == 1 ? lhs.n2() + rhs.n2() : lhs.n2()); for (int64 i0 = 0; i0 < result->n1(); ++i0) { @@ -276,7 +275,8 @@ class ReferenceUtil { out_dims[i] = lhs_dims[i] + rhs_dims[i]; } } - auto result = MakeUnique>(out_dims[0], out_dims[1], out_dims[2]); + auto result = + absl::make_unique>(out_dims[0], out_dims[1], out_dims[2]); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { @@ -310,8 +310,8 @@ class ReferenceUtil { out_dims[i] = lhs_dims[i] + rhs_dims[i]; } } - auto result = MakeUnique>(out_dims[0], out_dims[1], out_dims[2], - out_dims[3]); + auto result = absl::make_unique>(out_dims[0], out_dims[1], + out_dims[2], out_dims[3]); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { @@ -332,8 +332,8 @@ class ReferenceUtil { // Slices with index clamping template - static std::vector ClampSlice1D( - const tensorflow::gtl::ArraySlice& input, int64 start, int64 size) { + static std::vector ClampSlice1D(const absl::Span& input, + int64 start, int64 size) { start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { @@ -355,9 +355,9 @@ class ReferenceUtil { CHECK_LE(limits[1], input.n2()); CHECK_GE(strides[0], 1); CHECK_GE(strides[1], 1); - auto result = - MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), - CeilOfRatio(limits[1] - starts[1], strides[1])); + auto result = absl::make_unique>( + CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1])); 
for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { (*result)(i0, i1) = @@ -381,10 +381,10 @@ class ReferenceUtil { CHECK_GE(strides[0], 1); CHECK_GE(strides[1], 1); CHECK_GE(strides[2], 1); - auto result = - MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), - CeilOfRatio(limits[1] - starts[1], strides[1]), - CeilOfRatio(limits[2] - starts[2], strides[2])); + auto result = absl::make_unique>( + CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1]), + CeilOfRatio(limits[2] - starts[2], strides[2])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { @@ -415,11 +415,11 @@ class ReferenceUtil { CHECK_GE(strides[1], 1); CHECK_GE(strides[2], 1); CHECK_GE(strides[3], 1); - auto result = - MakeUnique>(CeilOfRatio(limits[0] - starts[0], strides[0]), - CeilOfRatio(limits[1] - starts[1], strides[1]), - CeilOfRatio(limits[2] - starts[2], strides[2]), - CeilOfRatio(limits[3] - starts[3], strides[3])); + auto result = absl::make_unique>( + CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1]), + CeilOfRatio(limits[2] - starts[2], strides[2]), + CeilOfRatio(limits[3] - starts[3], strides[3])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { @@ -460,8 +460,8 @@ class ReferenceUtil { template static std::unique_ptr> MapWithIndexArray4D( const Array4D& input, F&& map_function) { - auto result = MakeUnique>(input.planes(), input.depth(), - input.height(), input.width()); + auto result = absl::make_unique>( + input.planes(), input.depth(), input.height(), input.width()); for (int64 plane = 0; plane < input.planes(); ++plane) { for (int64 depth = 0; depth < input.depth(); ++depth) { for (int64 height = 0; height < input.height(); ++height) { @@ -495,8 +495,8 @@ class ReferenceUtil { template static std::unique_ptr> 
MapWithIndexArray4D( const Array4D& lhs, const Array4D& rhs, F&& map_function) { - auto result = MakeUnique>(lhs.planes(), lhs.depth(), - lhs.height(), lhs.width()); + auto result = absl::make_unique>(lhs.planes(), lhs.depth(), + lhs.height(), lhs.width()); for (int64 plane = 0; plane < lhs.planes(); ++plane) { for (int64 depth = 0; depth < lhs.depth(); ++depth) { for (int64 height = 0; height < lhs.height(); ++height) { @@ -530,7 +530,7 @@ class ReferenceUtil { int64 out1 = in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; - auto result = MakeUnique>(out0, out1); + auto result = absl::make_unique>(out0, out1); result->Fill(pad); int64 o0 = low_padding0; for (int64 i0 = 0; i0 < in0; ++i0) { @@ -631,7 +631,7 @@ class ReferenceUtil { Array4D result(output_bounds[0], output_bounds[1], output_bounds[2], output_bounds[3]); result.Each( - [&](tensorflow::gtl::ArraySlice indices, NativeT* value) { + [&](absl::Span indices, NativeT* value) { for (int i = 0; i < 4; ++i) { bool in_low_padding = indices[i] < pad_low[i]; bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; @@ -669,7 +669,7 @@ class ReferenceUtil { static std::unique_ptr> ApplyElementwise2D( F&& f, const Array2D& array1, const Array2D&... arrays) { AssertSameSize2D(array1, arrays...); - auto result = MakeUnique>(array1.n1(), array1.n2()); + auto result = absl::make_unique>(array1.n1(), array1.n2()); for (int64 i = 0; i < array1.n1(); ++i) { for (int64 j = 0; j < array1.n2(); ++j) { (*result)(i, j) = f(array1(i, j), arrays(i, j)...); diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 8091bed4996a753649a5ecedda69a1ae48fb5897..3ec0192148492c2516bf1c14fd4b960b08014388 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -18,12 +18,12 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -36,7 +36,7 @@ namespace { class ReferenceUtilTest : public ::testing::Test { protected: ReferenceUtilTest() { - matrix_ = MakeUnique>(rows_, cols_); + matrix_ = absl::make_unique>(rows_, cols_); // [1.f 2.f 3.f] // [4.f 5.f 6.f] for (int64 i = 0; i < rows_; ++i) { @@ -112,8 +112,8 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { } TEST_F(ReferenceUtilTest, MapArray4D) { - auto input = MakeUnique>(/*planes=*/2, /*depth=*/3, - /*height=*/4, /*width=*/5); + auto input = absl::make_unique>(/*planes=*/2, /*depth=*/3, + /*height=*/4, /*width=*/5); input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); @@ -126,8 +126,8 @@ TEST_F(ReferenceUtilTest, MapArray4D) { } TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { - auto input = MakeUnique>(/*planes=*/2, /*depth=*/3, - /*height=*/4, /*width=*/5); + auto input = absl::make_unique>(/*planes=*/2, /*depth=*/3, + /*height=*/4, /*width=*/5); input->FillWithMultiples(1.0f); auto subtract_index = [](float value, int64 plane, int64 depth, int64 height, int64 width) { diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 44b22a5586dee3f7dd8ea0edbf9deb2090986ac8..97fcd37f6b89d6dd737c233ef19f55a8faa1b624 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -43,6 +43,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:framework_internal", 
"//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", ], ) @@ -62,6 +63,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 67886761813f0bb45a600661b017be91ffeade73..43fd8fe1bd0f41eb2ac5c42021a8ca4f63282646 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -23,12 +23,12 @@ limitations under the License. #include "grpcpp/create_channel.h" #include "grpcpp/security/credentials.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_stub.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/subprocess.h" @@ -46,7 +46,7 @@ class GRPCClientTestBase : public ::testing::Test { int port = tensorflow::internal::PickUnusedPortOrDie(); subprocess_.SetProgram( service_main_path, - {service_main_path, tensorflow::strings::Printf("--port=%d", port)}); + {service_main_path, absl::StrFormat("--port=%d", port)}); subprocess_.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_DUPPARENT); subprocess_.SetChannelAction(tensorflow::CHAN_STDERR, @@ -54,9 +54,8 @@ class GRPCClientTestBase : public ::testing::Test { CHECK(subprocess_.Start()); LOG(INFO) << "Launched subprocess"; - auto channel = - ::grpc::CreateChannel(tensorflow::strings::Printf("localhost:%d", port), - ::grpc::InsecureChannelCredentials()); + auto channel = ::grpc::CreateChannel(absl::StrFormat("localhost:%d", port), + ::grpc::InsecureChannelCredentials()); 
channel->WaitForConnected(gpr_time_add( gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(10, GPR_TIMESPAN))); LOG(INFO) << "Channel to server is connected on port " << port; diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc index c68c857c304138ff4318e243f66547c6acce1005..d6b5149a24c491d1e9d7cd9119b36d7eb2ad65d3 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -18,8 +18,8 @@ limitations under the License. #include "grpcpp/security/server_credentials.h" #include "grpcpp/server.h" #include "grpcpp/server_builder.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/rpc/grpc_service.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" @@ -44,7 +44,7 @@ int RealMain(int argc, char** argv) { xla::GRPCService::NewService().ConsumeValueOrDie(); ::grpc::ServerBuilder builder; - string server_address(tensorflow::strings::Printf("localhost:%d", port)); + string server_address(absl::StrFormat("localhost:%d", port)); builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); builder.RegisterService(service.get()); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 7331d2b54cfdb853ff6fcb2e02c5bdd9cf716779..26b48cf4196ce24a8a20f407f698d951e18193f9 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -69,6 +69,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -99,9 +100,11 @@ cc_library( ":bfloat16_support", ":hlo", ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -175,6 +178,10 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -191,6 +198,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -226,6 +234,7 @@ cc_library( hdrs = ["hlo_evaluator.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_query", ":shape_inference", "//tensorflow/compiler/xla:literal", @@ -237,6 +246,12 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -263,6 +278,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -311,6 +327,11 @@ cc_library( "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -337,7 +358,7 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -363,6 +384,7 @@ cc_library( "//tensorflow/compiler/xla:util", 
"//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/types:span", ], ) @@ -389,7 +411,8 @@ cc_library( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -419,6 +442,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -449,6 +473,9 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -517,6 +544,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -552,6 +580,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -574,6 +603,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -615,6 +647,10 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], alwayslink = 1, ) @@ -647,6 +683,10 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + 
"@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -669,6 +709,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -719,6 +760,10 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -736,6 +781,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:ptr_util", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -753,9 +799,11 @@ cc_library( ":hlo_execution_profile", ":hlo_graph_dumper", ":hlo_proto", + ":maybe_owning_device_memory", ":shaped_buffer", ":stream_pool", "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -766,6 +814,10 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", ], ) @@ -784,6 +836,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/types:span", ], ) @@ -813,6 +866,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -831,6 +887,8 @@ cc_library( 
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -847,6 +905,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -864,6 +923,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -874,6 +936,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -908,6 +971,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -917,12 +982,14 @@ tf_cc_test( deps = [ ":buffer_liveness", ":hlo", + ":hlo_dataflow_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", ], ) @@ -950,6 +1017,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -975,8 +1046,10 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + 
"@com_google_absl//absl/memory", ], ) @@ -996,6 +1069,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1031,6 +1106,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1049,6 +1125,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1059,12 +1136,15 @@ cc_library( deps = [ ":hlo", ":hlo_casting_utils", + ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", ], ) @@ -1074,6 +1154,7 @@ cc_library( hdrs = ["hlo_module_group_util.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_module_group_metadata", ":hlo_reachability", "//tensorflow/compiler/xla:status", @@ -1082,6 +1163,9 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1101,6 +1185,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], ) @@ -1108,17 +1193,18 @@ tf_cc_test( name = "hlo_scheduling_test", srcs = ["hlo_scheduling_test.cc"], deps = [ - ":buffer_value", ":heap_simulator", ":hlo", + ":hlo_dce", ":hlo_ordering", + ":hlo_parser", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - 
"//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1142,6 +1228,7 @@ cc_library( ":hlo_pass", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1167,6 +1254,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1181,6 +1269,9 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1196,8 +1287,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1216,6 +1309,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -1231,6 +1326,7 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1245,6 +1341,7 @@ cc_library( ":while_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1267,6 +1364,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1276,6 +1374,7 @@ cc_library( hdrs = 
["algebraic_simplifier.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_creation_utils", ":hlo_pass", ":hlo_query", @@ -1289,6 +1388,11 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -1298,6 +1402,7 @@ tf_cc_test( deps = [ ":algebraic_simplifier", ":hlo", + ":hlo_casting_utils", ":hlo_matchers", ":hlo_pass", "//tensorflow/compiler/xla:literal", @@ -1312,6 +1417,8 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1323,8 +1430,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":hlo_pass", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1377,6 +1483,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1399,6 +1506,41 @@ tf_cc_test( ], ) +cc_library( + name = "convolution_feature_group_converter", + srcs = ["convolution_feature_group_converter.cc"], + hdrs = ["convolution_feature_group_converter.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "convolution_feature_group_converter_test", + size = "small", + srcs = 
["convolution_feature_group_converter_test.cc"], + deps = [ + ":convolution_feature_group_converter", + ":hlo", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + cc_library( name = "while_loop_analysis", srcs = ["while_loop_analysis.cc"], @@ -1406,8 +1548,7 @@ cc_library( deps = [ ":hlo", ":hlo_evaluator", - "//tensorflow/compiler/xla:literal", - "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", ], ) @@ -1422,6 +1563,8 @@ cc_library( ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -1435,6 +1578,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1549,6 +1693,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1569,6 +1714,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1583,6 +1729,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -1602,6 +1749,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", ], ) @@ -1621,6 +1769,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], alwayslink = True, # 
Contains per-platform computation placer registration ) @@ -1634,6 +1784,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1671,6 +1823,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/types:span", ], ) @@ -1711,6 +1864,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1725,6 +1880,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1756,6 +1912,8 @@ tf_cc_binary( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1772,6 +1930,8 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -1787,6 +1947,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1814,6 +1976,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1831,6 +1995,9 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + 
"@com_google_absl//absl/types:span", ], ) @@ -1849,6 +2016,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1890,6 +2061,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1926,6 +2099,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1946,6 +2120,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1966,6 +2142,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1983,6 +2160,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -1995,7 +2173,6 @@ cc_library( ":hlo_dataflow_analysis", ":logical_buffer", ":logical_buffer_analysis", - "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -2003,6 +2180,11 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + 
"@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -2053,6 +2235,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -2075,6 +2261,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2142,7 +2329,10 @@ cc_library( ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -2179,13 +2369,16 @@ cc_library( ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", - ":tuple_simplifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2225,6 +2418,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -2270,6 +2464,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -2306,6 +2501,9 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2343,6 +2541,7 @@ tf_cc_test( 
"//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2359,6 +2558,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2369,6 +2569,7 @@ tf_cc_test( ":hlo", ":hlo_constant_folding", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -2390,6 +2591,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2404,6 +2606,8 @@ cc_library( "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -2464,6 +2668,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -2510,6 +2715,22 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "maybe_owning_device_memory", + srcs = [ + "maybe_owning_device_memory.cc", + ], + hdrs = [ + "maybe_owning_device_memory.h", + ], + deps = [ + ":device_memory_allocator", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", ], ) @@ -2519,6 +2740,7 @@ cc_library( hdrs = ["elemental_ir_emitter.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_module_config", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -2527,11 +2749,14 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + 
"//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:transform_utils", ], @@ -2563,10 +2788,11 @@ cc_library( ":computation_layout", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -2579,6 +2805,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2615,8 +2842,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -2650,6 +2877,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", ], alwayslink = 1, ) @@ -2666,6 +2896,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2747,9 +2978,9 @@ cc_library( hdrs = ["stream_pool.h"], deps = [ "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ 
-2847,6 +3078,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -2867,7 +3100,7 @@ cc_library( hdrs = ["tuple_util.h"], deps = [ ":hlo", - "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -2893,7 +3126,8 @@ cc_library( ":hlo_creation_utils", ":tuple_util", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -2907,6 +3141,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2922,6 +3157,8 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -2949,6 +3186,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2982,13 +3220,13 @@ cc_library( cc_library( name = "source_map_util", - srcs = ["source_map_util.cc"], + srcs = [], hdrs = ["source_map_util.h"], deps = [ ":executable", "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", ], ) @@ -3003,6 +3241,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -3034,8 +3276,11 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", 
"//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -3044,11 +3289,13 @@ tf_cc_test( size = "small", srcs = ["hlo_parser_test.cc"], deps = [ + ":hlo_matchers", ":hlo_parser", "//tensorflow/compiler/xla:window_util", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", # fixdeps: keep + "@com_google_absl//absl/strings", ], ) @@ -3067,6 +3314,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 37834e1cc2657ff56f65a4f94eb973b9022eb8e1..7c078f07d72ab4243d50b7f7910cb7c794e306c4 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -22,13 +22,20 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" @@ -40,8 +47,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -122,6 +127,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleImag(HloInstruction* imag) override; + Status HandleIota(HloInstruction* instruction) override; + Status HandleConvolution(HloInstruction* convolution) override; Status HandleDivide(HloInstruction* divide) override; @@ -266,7 +273,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { StatusOr OptimizeDotOfConcat(HloInstruction* dot); StatusOr OptimizeDotOfConcatHelper( - const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); StatusOr 
OptimizeDotOfGather(HloInstruction* dot); @@ -444,8 +451,7 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { Status AlgebraicSimplifierVisitor::HandleConcatenate( HloInstruction* concatenate) { - tensorflow::gtl::ArraySlice operands( - concatenate->operands()); + absl::Span operands(concatenate->operands()); if (operands.size() == 1) { // Unary concatenates are useless. ReplaceInstructionIfSameShape(concatenate, operands[0]); @@ -540,7 +546,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr unique_scalar = MakeUnique( + std::unique_ptr unique_scalar = absl::make_unique( LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); @@ -548,6 +554,14 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { constant, HloInstruction::CreateBroadcast(constant->shape(), scalar, {})); } + + // If a literal is an increasing sequence from zero, replace it with an iota. 
+ if (ShapeUtil::Rank(constant->shape()) == 1 && + ShapeUtil::ElementsIn(constant->shape()) > 1 && + constant->literal().IsR1Iota()) { + return ReplaceWithNewInstruction( + constant, HloInstruction::CreateIota(constant->shape(), 0)); + } return Status::OK(); } @@ -575,7 +589,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { namespace { template Status InvertConstant(const HloInstruction& constant, Literal* result) { - return result->Populate([&](tensorflow::gtl::ArraySlice indices) { + return result->Populate([&](absl::Span indices) { return T{1.0} / constant.literal().Get(indices); }); } @@ -827,18 +841,18 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( TF_ASSIGN_OR_RETURN( HloInstruction * optimized_lhs_concat, - OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs, + OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs, rhs_contracting_dim, /*swapped=*/false)); if (optimized_lhs_concat) { return optimized_lhs_concat; } - return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs, + return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs, lhs_contracting_dim, /*swapped=*/true); } StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( - const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) { bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && lhs->concatenate_dimension() == lhs_contracting_dim && @@ -937,11 +951,12 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( } auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums)); + dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums)); + new_dot->set_precision_config(dot.precision_config()); if (add_result) { add_result = 
computation_->AddInstruction(HloInstruction::CreateBinary( - dot_shape, HloOpcode::kAdd, add_result, new_dot)); + dot.shape(), HloOpcode::kAdd, add_result, new_dot)); } else { add_result = new_dot; } @@ -1040,6 +1055,7 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( memoized_shape, left_operand, right_operand, dnums)); + memoized_inst->set_precision_config(dot->precision_config()); // Get pair {start, 0} or {0, start}. HloInstruction* original_start_indices = lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); @@ -1137,6 +1153,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers)); + new_dot->set_precision_config(dot->precision_config()); return ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } @@ -1232,9 +1249,8 @@ namespace { // return value = {1, 3} // // Precondition: input_dim_indices is sorted. 
-std::pair> ReshapeLeavesDimensionsUnmodified( - const HloInstruction* hlo, - tensorflow::gtl::ArraySlice input_dim_indices) { +absl::optional> ReshapeLeavesDimensionsUnmodified( + const HloInstruction* hlo, absl::Span input_dim_indices) { CHECK_EQ(HloOpcode::kReshape, hlo->opcode()); CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())); @@ -1252,11 +1268,11 @@ std::pair> ReshapeLeavesDimensionsUnmodified( } if (i >= unmodified_dims.size() || unmodified_dims[i].first != input_dim_index) { - return std::make_pair(false, std::vector()); + return absl::nullopt; } output_dim_indices.push_back(unmodified_dims[i].second); } - return std::make_pair(true, output_dim_indices); + return output_dim_indices; } // Returns true if the output of "instruction" is a permutation of the @@ -1385,6 +1401,15 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } + // broadcast(iota) -> iota. + if (operand->opcode() == HloOpcode::kIota) { + return ReplaceWithNewInstruction( + broadcast, + HloInstruction::CreateIota( + broadcast->shape(), + dims[Cast(operand)->iota_dimension()])); + } + // Merge two consecutive broadcasts into a single one. if (operand->opcode() == HloOpcode::kBroadcast) { std::vector new_dimensions; @@ -1439,6 +1464,19 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) { + // iota -> zero if the iota dimension never produces an element other than + // zero. 
+ auto* iota = Cast(instruction); + if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { + auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique())); + return ReplaceWithNewInstruction( + iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) { return ReplaceWithNewInstruction( @@ -1705,16 +1743,33 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { reshape, HloInstruction::CreateReshape(reshape->shape(), operand->mutable_operand(0))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { + *operand->mutable_shape() = reshape->shape(); + return ReplaceInstruction(reshape, operand); + } if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) { auto opt_dims = ReshapeLeavesDimensionsUnmodified( reshape, reshape->operand(0)->dimensions()); - if (opt_dims.first) { + if (opt_dims.has_value()) { return ReplaceWithNewInstruction( reshape, HloInstruction::CreateBroadcast( reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0), - opt_dims.second)); + *opt_dims)); + } + } + + // reshape(iota) -> iota. 
+ if (operand->opcode() == HloOpcode::kIota) { + auto* iota = Cast(operand); + auto opt_dims = + ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()}); + if (opt_dims.has_value()) { + CHECK_EQ(opt_dims->size(), 1); + return ReplaceWithNewInstruction( + reshape, + HloInstruction::CreateIota(reshape->shape(), opt_dims->front())); } } @@ -1748,8 +1803,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { } auto is_unstrided_slice = [](const HloInstruction* hlo) { - return c_all_of(hlo->slice_strides(), - [](int64 stride) { return stride == 1; }); + return absl::c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); }; if (slice->operand(0)->opcode() == HloOpcode::kSlice && is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) { @@ -1811,7 +1866,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (ShapeUtil::IsZeroElementArray(arg->shape()) || ShapeUtil::IsZeroElementArray(reduce->shape())) { @@ -1926,7 +1981,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { // This should make fusion easier or use less memory bandwidth in the unfused // case. if (arg->opcode() == HloOpcode::kConcatenate && - c_linear_search(reduce->dimensions(), arg->concatenate_dimension())) { + absl::c_linear_search(reduce->dimensions(), + arg->concatenate_dimension())) { HloInstruction* old_reduce = nullptr; for (HloInstruction* operand : arg->operands()) { HloInstruction* new_reduce = computation_->AddInstruction( @@ -1979,9 +2035,9 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( VLOG(10) << "Considering folding Pad: " << pad->ToString() << "\ninto reduce-window: " << reduce_window->ToString() - << (convert != nullptr ? 
tensorflow::strings::StrCat( - "\nvia convert: ", convert->ToString()) - : ""); + << (convert != nullptr + ? absl::StrCat("\nvia convert: ", convert->ToString()) + : ""); // Do not fold interior padding into ReduceWindow since the backends do not // support it. @@ -2144,6 +2200,11 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { transpose->dimensions()))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { + *operand->mutable_shape() = transpose->shape(); + return ReplaceInstruction(transpose, operand); + } + if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) { ReplaceWithBitcast(transpose); return Status::OK(); @@ -2167,7 +2228,141 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( .CloneToUnique())), {})); } + const auto& window = convolution->window(); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + + // Try to merge padding/dilation of the input with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr { + if (lhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(lhs->operand(1), 0)) { + return false; + } + + const auto& padding = lhs->padding_config(); + + // Can't pad batch or feature dims. + for (int64 dim : + {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { + return false; + } + } + + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. 
+ Window new_window = window; + for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); + // Edge padding composes with itself in the straightforward way, but + // composing interior padding is nontrivial, and we cowardly refuse to + // think about it. If we see interior padding in either the kPad or conv, + // bail if there's any sort of padding in the other. + if (p.interior_padding() != 0 && + (w.padding_low() != 0 || w.padding_high() != 0 || + w.base_dilation() != 1)) { + return false; + } + if (w.base_dilation() != 1 && + (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0)) { + return false; + } + + w.set_padding_low(w.padding_low() + p.edge_padding_low()); + w.set_padding_high(w.padding_high() + p.edge_padding_high()); + if (p.interior_padding() != 0) { + CHECK_EQ(w.base_dilation(), 1); + w.set_base_dilation(1 + p.interior_padding()); + } + } + + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs->mutable_operand(0), rhs}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; + }()); + + if (folded_input_pad) { + return Status::OK(); + } + + // Try to merge dilation of the filter with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr { + if (rhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(rhs->operand(1), 0)) { + return false; + } + + const auto& padding = rhs->padding_config(); + + // Can't pad or dilate feature dims. 
+ for (int64 dim : {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { + return false; + } + } + + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = convolution->window(); + for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); + + // We can only do this transformation if p adds dilation to the filter -- + // edge padding on the filter is not supported in conv. + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { + return false; + } + + // Nothing to do if the kPad for this dim is entirely a nop. + if (p.interior_padding() == 0) { + continue; + } + + // We cowardly refuse to think about how dilation composes with itself; + // bail if both the kPad and conv have dilation on this dimension. 
+ if (w.window_dilation() > 1) { + return false; + } + CHECK_EQ(w.window_dilation(), 1); + w.set_window_dilation(1 + p.interior_padding()); + w.set_size(rhs->operand(0)->shape().dimensions( + dnums.kernel_spatial_dimensions(dim))); + } + + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs, rhs->mutable_operand(0)}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; + }()); + + if (folded_filter_pad) { + return Status::OK(); + } + if (!enable_conv_simplification_) { return Status::OK(); } @@ -2184,8 +2379,6 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( return Status::OK(); } - const ConvolutionDimensionNumbers& dnums = - convolution->convolution_dimension_numbers(); const Shape& input_shape = lhs->shape(); const Shape& filter_shape = rhs->shape(); const Shape& convolution_shape = convolution->shape(); @@ -2285,6 +2478,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto dot = computation_->AddInstruction(HloInstruction::CreateDot( dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); + dot->set_precision_config(convolution->precision_config()); + return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index c48196e861a559a5abfa360841ec70b39356fa2b..b864c372fa5877ca329d2efbbf7d747c763ae2c0 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -47,7 +47,7 @@ class AlgebraicSimplifier : public HloPassInterface { enable_dot_strength_reduction_(enable_dot_strength_reduction), enable_conv_simplification_(enable_conv_simplification) {} ~AlgebraicSimplifier() override = default; - tensorflow::StringPiece name() const override { return "algsimp"; } + 
absl::string_view name() const override { return "algsimp"; } // Run algebraic simplification on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 862cbeeba6b82e1f24a6616b3237dc47d022e9af..43a891e4fa163e833692a8e71b8f2f21d377e323 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -18,11 +18,15 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" @@ -34,13 +38,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" - -using ::testing::ElementsAre; namespace xla { namespace { +using ::testing::ElementsAre; + namespace op = xla::testing::opcode_matchers; AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { @@ -290,6 +293,21 @@ TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { EXPECT_THAT(root, op::Constant()); } +TEST_F(AlgebraicSimplifierTest, IotaToBroadcast) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f}))); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); +} + // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -513,7 +531,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({0.f, 1.f, 2.f}))); + LiteralUtil::CreateR1({1.f, 2.f, 3.f}))); builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, constant)); @@ -1428,6 +1446,37 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } +// Test transforming reshapes and transposes of rng. 
+TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { + HloComputation::Builder builder(TestName()); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction* rng0 = builder.AddInstruction( + HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {2, 2}), + RandomDistribution::RNG_UNIFORM, {zero, one})); + + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(rng0->shape(), rng0, {1, 0})); + Shape reshape_shape = builder + .AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {4}), transpose)) + ->shape(); + + auto computation = module().AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + // Verify that reshape(transpose(rng)) is replaced by a single rng of the + // same shape as the reshape. + EXPECT_THAT(computation->root_instruction(), op::Rng()); + EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(), + reshape_shape)); +} + // Test transforming reshapes to bitcasts under various conditions. 
TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { HloComputation::Builder builder(TestName()); @@ -1789,6 +1838,126 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { op::Reshape(op::Broadcast(param))); } +TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction(HloInstruction::CreateIota( + ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), 2)); + Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}); + builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); +} + +TEST_F(AlgebraicSimplifierTest, IotaEffectiveScalar) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {1, 1}), 0)); + auto result_shape = iota->shape(); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + auto root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_EQ(0.0f, root->operand(0)->literal().GetFirstElement()); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2_6) { + HloComputation::Builder builder(TestName()); + auto iota = 
builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2}), 1)); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), iota)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_EQ(Cast(computation->root_instruction()) + ->iota_dimension(), + 3); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x2_6x1x1x2) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 2}), 2)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1, 1, 2}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + 
EXPECT_THAT(computation->root_instruction(), op::Iota()); + const int64 iota_dim = + Cast(computation->root_instruction()) + ->iota_dimension(); + EXPECT_THAT(iota_dim, ::testing::AnyOf(1, 2, 3)); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2)); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); +} + TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -1975,6 +2144,264 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) { EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values)); } +// Used for TEST_Ps that test merging (or not) of a kPad instruction into a +// convolution's Window. 
+struct ConvPaddingTestcase { + ConvPaddingTestcase(absl::string_view padding, + absl::string_view orig_conv_window, + absl::string_view expected_conv_window) + : ConvPaddingTestcase(padding, orig_conv_window, expected_conv_window, + /*pad_value=*/0) {} + + ConvPaddingTestcase(absl::string_view padding, + absl::string_view orig_conv_window, + absl::string_view expected_conv_window, float pad_value) + : padding(padding), + orig_conv_window(orig_conv_window), + expected_conv_window(expected_conv_window), + pad_value(pad_value) {} + + string ToString() const { + return absl::StrFormat( + "padding=%s, orig_conv_window=%s, expected_conv_window=%s, " + "pad_value=%f", + padding, orig_conv_window, expected_conv_window, pad_value); + } + + string padding; + string orig_conv_window; + string expected_conv_window; + float pad_value; +}; + +// ConvInputPaddingTest (and its one associated TEST_P testcase) checks that a +// computation that does +// +// conv(pad(param0, padding=padding), param1), window=orig_conv_window +// +// gets transformed by AlgebraicSimplifier to +// +// conv(param0, param1), window=expected_conv_window +// +// or, if expected_conv_window is the empty string, checks that +// AlgebraicSimplifier does *not* transform the original convolution. +class ConvInputPaddingTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_CASE_P( + ConvInputPaddingTestCases, ConvInputPaddingTest, + ::testing::ValuesIn(std::vector{ + // Merge this edge padding into the conv. + {"0_0x0_0x1_1x2_2", "", "pad=1_1x2_2"}, + // Merge this edge padding with the conv's edge padding. + {"0_0x0_0x1_2x3_4", "pad=10_10x20_20", "pad=11_12x23_24"}, + // Merge this interior-padded kPad with the unpadded conv. The 3x6 + // interior padding gets transformed to 4x7 conv lhs dilation. + {"0_0x0_0x1_2_3x4_5_6", "", "pad=1_2x4_5 lhs_dilate=4x7"}, + // kPad has dilation on one dim, conv has it on the other; merge them. 
+ {"0_0x0_0x0_0_1x0_0_0", "lhs_dilate=1x10", "lhs_dilate=2x10"}, + // kPad has dilation and edge padding on one dim, conv has them on the + // other; merge them. + {"0_0x0_0x0_1_1x0_0_0", "pad=0_0x3_0 lhs_dilate=1x10", + "pad=0_1x3_0 lhs_dilate=2x10"}, + + // Don't transform if the pad value is nonzero. + {"0_0x0_0x1_1x2_2", "", "", /*pad_value=*/1}, + + // We refuse to transform the following because on some dimension, one + // of the kPad and conv has dilation and the other has some sort of + // padding. + {"0_0x0_0x0_0_1x0_0", "pad=1_0x0_0", ""}, + {"0_0x0_0x0_0_1x0_0", "pad=0_1x0_0", ""}, + {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""}, + {"0_0x0_0x1_0_0x0_0", "lhs_dilate=2x1", ""}, + {"0_0x0_0x0_1_0x0_0", "lhs_dilate=2x1", ""}, + {"0_0x0_0x0_0_1x0_0", "lhs_dilate=2x1", ""}, + + // We can't merge feature or batch padding into the conv. + {"1_0x0_0x0_0x0_0", "", ""}, + {"0_0x1_0x0_0x0_0", "", ""}, + })); + +TEST_P(ConvInputPaddingTest, DoTest) { + ConvPaddingTestcase testcase = GetParam(); + + // It would be better to put the testcase's ToString into the test name, but + // gUnit has constraints on what can go into test names, and any reasonable + // implementation of ToString() seems to violate them. 
+ SCOPED_TRACE(testcase.ToString()); + + auto builder = HloComputation::Builder(TestName()); + auto* input = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1024, 128, 100, 100}), // bf01 + "input")); + auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(testcase.pad_value))); + + PaddingConfig padding_config = + ParsePaddingConfig(testcase.padding).ValueOrDie(); + auto* lhs_pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeInference::InferPadShape(input->shape(), pad_value->shape(), + padding_config) + .ValueOrDie(), + input, pad_value, padding_config)); + + auto* filter = builder.AddInstruction(HloInstruction::CreateParameter( + 1, + ShapeUtil::MakeShape( + F32, {lhs_pad->shape().dimensions(1), 256, 3, 3}), // io01 + "input")); + + ConvolutionDimensionNumbers dnums = + ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie(); + Window window = + ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window)) + .ValueOrDie(); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(), + window, dnums) + .ValueOrDie(), + lhs_pad, filter, window, dnums)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + if (testcase.expected_conv_window.empty()) { + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + } else { + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + auto* conv = module->entry_computation()->root_instruction(); + SCOPED_TRACE(module->ToString()); + ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + EXPECT_EQ(window_util::ToString(conv->window()), + absl::StrCat("size=3x3 ", testcase.expected_conv_window)); + } +} + +// ConvFilterPaddingTest (and its one associated TEST_P) checks that a +// computation that does +// +// 
conv(param0, pad(param1, padding=padding)), window=orig_conv_window +// +// gets transformed by AlgebraicSimplifier to +// +// conv(param0, param1), window=expected_conv_window +// +// or, if expected_conv_window is the empty string, checks that +// AlgebraicSimplifier does *not* transform the original convolution. +class ConvFilterPaddingTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface {}; + +INSTANTIATE_TEST_CASE_P( + ConvFilterPaddingTestCases, ConvFilterPaddingTest, + ::testing::ValuesIn(std::vector{ + // Can only merge interior padding on the filter's spatial dimensions; + // all + // other paddings (edge padding and interior padding on the channel + // dims) + // should be rejected out of hand. + {"1_0_0x0_0_0x0_0x0_0", "", ""}, + {"0_1_0x0_0_0x0_0x0_0", "", ""}, + {"0_0_1x0_0_0x0_0x0_0", "", ""}, + {"0_0_0x1_0_0x0_0x0_0", "", ""}, + {"0_0_0x0_1_0x0_0x0_0", "", ""}, + {"0_0_0x0_0_1x0_0x0_0", "", ""}, + {"0_0_0x0_0_0x1_0x0_0", "", ""}, + {"0_0_0x0_0_0x0_1x0_0", "", ""}, + {"0_0_0x0_0_0x0_0x1_0", "", ""}, + {"0_0_0x0_0_0x0_0x0_1", "", ""}, + + // Interior padding on spatial dims can be merged into the conv, so long + // as the conv and pad don't have interior padding on the same dim. + {"0_0x0_0x0_0_5x0_0", "", "rhs_dilate=6x1"}, + {"0_0x0_0x0_0x0_0_10", "", "rhs_dilate=1x11"}, + {"0_0x0_0x0_0_10x0_0_100", "", "rhs_dilate=11x101"}, + {"0_0x0_0x0_0_1x0_0", "rhs_dilate=1x10", "rhs_dilate=2x10"}, + {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x1", "rhs_dilate=10x6"}, + + // Can't merge if for a given dim there's interior padding on both the + // pad and conv. + {"0_0x0_0x0_0_1x0_0", "rhs_dilate=2x10", ""}, + {"0_0x0_0x0_0x0_0_5", "rhs_dilate=10x2", ""}, + + // Don't transform if the pad value is nonzero. 
+ {"0_0x0_0x0_0_5x0_0", "", "", /*pad_value=*/1}, + })); + +TEST_P(ConvFilterPaddingTest, DoIt) { + ConvPaddingTestcase testcase = GetParam(); + + // It would be better to put the testcase's ToString into the test name, but + // gUnit has constraints on what can go into test names, and any reasonable + // implementation of ToString() seems to violate them. + SCOPED_TRACE(testcase.ToString()); + + auto builder = HloComputation::Builder(TestName()); + auto* pad_value = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(testcase.pad_value))); + auto* filter = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {128, 256, 3, 3}), // io01 + "input")); + PaddingConfig padding_config = + ParsePaddingConfig(testcase.padding).ValueOrDie(); + auto* rhs_pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeInference::InferPadShape(filter->shape(), pad_value->shape(), + padding_config) + .ValueOrDie(), + filter, pad_value, padding_config)); + + auto* input = builder.AddInstruction(HloInstruction::CreateParameter( + 0, + ShapeUtil::MakeShape( + F32, {1024, rhs_pad->shape().dimensions(0), 100, 100}), // bf01 + "input")); + + ConvolutionDimensionNumbers dnums = + ParseConvolutionDimensionNumbers("bf01_io01->bf01").ValueOrDie(); + Window window = ParseWindow(absl::StrFormat("size=%dx%d %s", + rhs_pad->shape().dimensions(2), + rhs_pad->shape().dimensions(3), + testcase.orig_conv_window)) + .ValueOrDie(); + auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), + window, dnums) + .ValueOrDie(), + input, rhs_pad, window, dnums)); + + // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place + // after the transformation. 
+ PrecisionConfigProto precision_config; + precision_config.add_operand_precision(PrecisionConfigProto::HIGH); + precision_config.add_operand_precision(PrecisionConfigProto::HIGHEST); + orig_conv->set_precision_config(precision_config); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + if (testcase.expected_conv_window.empty()) { + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); + } else { + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); + auto* conv = module->entry_computation()->root_instruction(); + SCOPED_TRACE(module->ToString()); + ASSERT_THAT(conv, op::Convolution(op::Parameter(), op::Parameter())); + EXPECT_EQ(window_util::ToString(conv->window()), + absl::StrFormat("size=%dx%d %s", + conv->operand(1)->shape().dimensions(2), + conv->operand(1)->shape().dimensions(3), + testcase.expected_conv_window)); + EXPECT_THAT( + conv->precision_config().operand_precision(), + ElementsAre(PrecisionConfigProto::HIGH, PrecisionConfigProto::HIGHEST)); + } +} + TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { struct ConvTestOptions { int in_batch = 10; @@ -2006,7 +2433,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { // Builds a convolution from and runs algebraic simplification on // the computation. Returns a string description of the result of // simplification. 
- auto build_and_simplify = [&options]() -> string { + auto build_and_simplify = [&]() -> string { HloComputation::Builder b(TestName()); Window window; @@ -2078,7 +2505,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { auto out_dims = in_dims; out_dims[in_channel_idx] = options.f_output_channels; - auto make_shape = [](tensorflow::gtl::ArraySlice dims, + auto make_shape = [](absl::Span dims, bool minor_to_major_layout) { if (minor_to_major_layout) { return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3}); @@ -2112,9 +2539,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { root->operand(0)->opcode() == HloOpcode::kDot) { auto lhs_shape = root->operand(0)->operand(0)->shape(); auto rhs_shape = root->operand(0)->operand(1)->shape(); - return tensorflow::strings::StrCat( - tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ", - tensorflow::str_util::Join(rhs_shape.dimensions(), "x")); + return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ", + absl::StrJoin(rhs_shape.dimensions(), "x")); } return "UNEXPECTED CHANGE"; }; @@ -2617,6 +3043,47 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); } +// Test that a broadcast of an iota can be merged to one iota. 
+TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* iota = + builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1)); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); + builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); + EXPECT_EQ(Cast(root)->iota_dimension(), 2); +} + +// Test that a broadcast of an iota can be merged to one iota. +TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { + HloComputation::Builder builder(TestName()); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); + HloInstruction* iota = + builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1)); + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); + EXPECT_EQ(Cast(root)->iota_dimension(), 2); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector input_spatials; std::vector symmetric_pad_spatials; @@ -2629,11 +3096,10 @@ struct PadReduceWindowEffectiveBroadcastCase { bool should_become_broadcast; string ToTestCaseName() const { - return 
tensorflow::strings::StrCat( - tensorflow::str_util::Join(input_spatials, ","), ";", - tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";", - tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a, - ";", should_become_broadcast); + return absl::StrCat(absl::StrJoin(input_spatials, ","), ";", + absl::StrJoin(symmetric_pad_spatials, ","), ";", + absl::StrJoin(reduce_window_spatials, ","), ";", + prepend_a, ";", should_become_broadcast); } }; @@ -2651,8 +3117,8 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { // a and b are parallel bounds we can either turn into a B F S0 S1 or // `B S0 S1 F` kind of pattern. - auto decorate_spatials = [&param](tensorflow::gtl::ArraySlice spatials, - int64 a, int64 b) { + auto decorate_spatials = [&param](absl::Span spatials, int64 a, + int64 b) { std::vector result; if (param.prepend_a) { result.push_back(a); diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 51ebc4763b612884a4453edec5711f78c4006fc3..1ed6142dcecdc830cb7b8386e0cc20a2ea54aa7f 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -17,15 +17,15 @@ limitations under the License. 
#include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -69,8 +69,7 @@ StatusOr AllocationTracker::RegisterInternal( return InvalidArgument( "AllocationTracker for platform %s cannot register buffer from " "platform %s", - backend_->platform()->Name().c_str(), - shaped_buffer.platform()->Name().c_str()); + backend_->platform()->Name(), shaped_buffer.platform()->Name()); } } @@ -91,8 +90,9 @@ StatusOr AllocationTracker::RegisterInternal( // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer // into a regular ShapedBuffer, which is stored in // handle_to_shaped_buffers_. - handle_to_shaped_buffers_[handle].emplace_back(MakeUnique( - ReleaseIfScopedShapedBuffer(std::move(shaped_buffer)))); + handle_to_shaped_buffers_[handle].emplace_back( + absl::make_unique( + ReleaseIfScopedShapedBuffer(std::move(shaped_buffer)))); } GlobalDataHandle result; @@ -124,7 +124,7 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) { // "handle does not exist". auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { - return NotFound("no allocation record for global data handle: %lld", + return NotFound("no allocation record for global data handle: %d", data.handle()); } for (auto& shaped_buffer : it->second) { @@ -143,7 +143,7 @@ StatusOr> AllocationTracker::DeconstructTuple( // the same for all buffers across replicas. 
const ShapedBuffer* shaped_buffer = replicated_buffers[0]; if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { - return InvalidArgument("global data handle %lld is not a tuple", + return InvalidArgument("global data handle %d is not a tuple", data.handle()); } // If the on-host representation is a tuple, then the on-device one should be @@ -200,14 +200,14 @@ StatusOr> AllocationTracker::ResolveInternal( VLOG(2) << "resolve:" << data.handle(); auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { - return NotFound("no allocation record for global data handle: %lld", + return NotFound("no allocation record for global data handle: %d", data.handle()); } std::vector replicated_buffers; for (const auto& shaped_buffer : it->second) { if (shaped_buffer == nullptr) { - return InvalidArgument( - "global data handle %lld was previously deallocated", data.handle()); + return InvalidArgument("global data handle %d was previously deallocated", + data.handle()); } replicated_buffers.push_back(shaped_buffer.get()); } diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index d12be3e007fe0b16ac850d64521f0025d481b5d2..5c180cbdd492031e133b81149f0f4698619b7788 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -111,11 +112,11 @@ StatusOr Backend::BorrowStream(se::StreamExecutor* executor) { return stream_pools_.at(executor).BorrowStream(executor); } -Backend::Backend( - se::Platform* platform, Compiler* compiler, - tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager, ComputationPlacer* computation_placer, - int intra_op_parallelism_threads) +Backend::Backend(se::Platform* platform, Compiler* compiler, + absl::Span stream_executors, + TransferManager* transfer_manager, + ComputationPlacer* computation_placer, + int intra_op_parallelism_threads) : platform_(platform), compiler_(compiler), transfer_manager_(transfer_manager), @@ -127,8 +128,8 @@ Backend::Backend( } } // Create a memory allocator for the valid stream executors. - memory_allocator_ = - MakeUnique(platform, stream_executors); + memory_allocator_ = absl::make_unique( + platform, stream_executors); CHECK(!stream_executors_.empty()) << "Service found no devices for backend " << platform_->Name() << '.'; @@ -176,7 +177,7 @@ StatusOr Backend::stream_executor( } } return InvalidArgument("device %s not supported by XLA service", - device_name(device_ordinal).c_str()); + device_name(device_ordinal)); } StatusOr Backend::devices_equivalent(int device_ordinal_a, diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 1bc3796fa48c1627538474d04ef5358ba64dfce9..a2dafbe803f8bd5f23e4e9f3f6d3e6f744c9fab9 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -21,6 +21,8 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -28,8 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -130,7 +130,7 @@ class Backend { // Return a string identifier for the given device, eg: "GPU:3". string device_name(int device_ordinal) const { - return tensorflow::strings::StrCat(platform_->Name(), ":", device_ordinal); + return absl::StrCat(platform_->Name(), ":", device_ordinal); } // Returns true if the devices with the given ordinals are equivalent from @@ -149,7 +149,7 @@ class Backend { private: struct EigenThreadPoolWrapper; Backend(se::Platform* platform, Compiler* compiler, - tensorflow::gtl::ArraySlice stream_executors, + absl::Span stream_executors, TransferManager* transfer_manager, ComputationPlacer* computation_placer, int intra_op_parallelism_threads); diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index 2099916509acdbc2680cc2b5bd405e96f2f7bfb8..a16b85a0a5e3f72f54e9733bb974b01377e0c358 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -63,6 +64,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); + new_dot->set_precision_config(batch_dot->precision_config()); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, MakeReshapeHlo(batch_dot->shape(), new_dot)); @@ -76,7 +78,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( return true; } -tensorflow::StringPiece BatchDotSimplification::name() const { +absl::string_view BatchDotSimplification::name() const { return "batch-dot-simplification"; } @@ -84,10 +86,10 @@ StatusOr BatchDotSimplification::Run(HloModule* module) { bool changed = false; std::vector dot_instrs; for (HloComputation* computation : module->MakeNonfusionComputations()) { - c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), - [](HloInstruction* instr) { - return instr->opcode() == HloOpcode::kDot; - }); + absl::c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), + [](HloInstruction* instr) { + return instr->opcode() == HloOpcode::kDot; + }); } for (HloInstruction* dot_instr : dot_instrs) { TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one, diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h index c0ca8d8ebac1a3b218e7bd4d6db02b69cfb6916f..79d37f08d3553321ebbabc44c8f2488b194954d5 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.h +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -28,7 +28,7 @@ namespace xla { class BatchDotSimplification : public HloPassInterface { public: StatusOr Run(HloModule* module) override; - tensorflow::StringPiece name() const override; 
+ absl::string_view name() const override; private: StatusOr ElideDegenerateBatchDimensionFromBatchDot( diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index c4cd60c1201f7ddbf0aba4b6d587952531b74bfa..ec281ae68fe76bac4029058997c44b1f7e71aeae 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -33,9 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -43,7 +43,7 @@ namespace xla { namespace { -using tensorflow::gtl::optional; +using absl::optional; // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. 
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 7ae202c583516443a6263403fb5460d1adbabd97..76e32174f3ee7d319df6f1f465e19d265d5330f2 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -36,7 +36,7 @@ class BatchNormExpander : public HloPassInterface { rewrite_inference_op_(rewrite_inference_op), rewrite_grad_op_(rewrite_grad_op) {} ~BatchNormExpander() = default; - tensorflow::StringPiece name() const override { return "batchnorm_expander"; } + absl::string_view name() const override { return "batchnorm_expander"; } // Run operation expander on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index a725351462809e5b670bbf1d79d2dded87e54f07..aba0d9bb5b977d89656580df46838eefb8cd6662 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -32,7 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 1b8b2d204503576c3fcb02f6d5b37f2db45e1768..d63287539dfde5bb4890ab8303ef2205133d8125 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h index c9398387098fad84ba28735c30e426fedd9b0cb0..5dcd31b83d24f836d31f44181f39cb8371ca1033 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -37,7 +37,7 @@ class BFloat16ConversionFolding : public HloPassInterface { : bfloat16_support_(bfloat16_support) {} ~BFloat16ConversionFolding() override = default; - tensorflow::StringPiece name() const override { return "bfloat16-fold"; } + absl::string_view name() const override { return "bfloat16-fold"; } // Run BF16 conversion folding on the given computation. Returns whether the // computation was changed. 
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 7cf05ca443c00c3b40eeb7d756cf216b45c45c39..6363a21c3bafe8353a6ebfde405bb7a3736c2074 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -235,8 +235,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, - sum, /*replica_group_ids=*/{}, /*barrier=*/"", - /*all_reduce_id=*/tensorflow::gtl::nullopt)); + sum, /*replica_groups=*/{}, /*barrier=*/"", + /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 16e99b57220cc185fbfaa75d30a0de709cf61ee7..d5b1148058898596bfdb837826a590bbc74e202a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -15,12 +15,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -34,11 +35,6 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum and sort which can have a tuple - // output. - Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleSort(HloInstruction* sort) override; - static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { BFloat16NormalizationVisitor visitor(computation, bfloat16_support); @@ -73,8 +69,7 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { // Inserts conversion HLOs to replace the called computations' BF16 // operands/outputs to F32. 
Status ConvertCalledComputations( - HloInstruction* hlo, - tensorflow::gtl::ArraySlice bf16_called_comps); + HloInstruction* hlo, absl::Span bf16_called_comps); HloComputation* computation_; const BFloat16Support* bfloat16_support_; @@ -118,8 +113,7 @@ Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand( } Status BFloat16NormalizationVisitor::ConvertCalledComputations( - HloInstruction* hlo, - tensorflow::gtl::ArraySlice bf16_called_comps) { + HloInstruction* hlo, absl::Span bf16_called_comps) { std::map cloned_computations; for (auto& comp : bf16_called_comps) { auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone()); @@ -150,23 +144,6 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations( return Status::OK(); } -Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( - HloInstruction* crs) { - if (!ShapeUtil::IsTuple(crs->shape())) { - return HandleInstruction(crs); - } else { - return HandleMultipleOutputs(crs); - } -} - -Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) { - if (!ShapeUtil::IsTuple(sort->shape())) { - return HandleInstruction(sort); - } else { - return HandleMultipleOutputs(sort); - } -} - Status BFloat16NormalizationVisitor::HandleMultipleOutputs( HloInstruction* hlo) { std::vector operand_types(hlo->operand_count()); @@ -380,6 +357,12 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kConditional) { return Status::OK(); } + // TODO(b/112040122): Correctly normalize variadic reduce. 
+ if ((hlo->opcode() == HloOpcode::kSort || + hlo->opcode() == HloOpcode::kCrossReplicaSum) && + ShapeUtil::IsTuple(hlo->shape())) { + return HandleMultipleOutputs(hlo); + } return HandleInstruction(hlo); } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h index 2a60fe0af3218484acb95e6c69815d551350764c..30b6346312790f0a199f96f1956ba9ce3e617f72 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.h +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -31,7 +31,7 @@ class BFloat16Normalization : public HloPassInterface { : bfloat16_support_(bfloat16_support) {} ~BFloat16Normalization() override = default; - tensorflow::StringPiece name() const override { return "bf16-normalization"; } + absl::string_view name() const override { return "bf16-normalization"; } // Run BF16 normalization on the given computation. Returns whether the // computation was changed. @@ -54,7 +54,7 @@ class BFloat16MixedPrecisionRemoval : public HloPassInterface { ~BFloat16MixedPrecisionRemoval() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "bf16-mixed-precision-removal"; } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index f9f1f64998f5b925102dc238941897ff6d441b3f..b08705d4c2b644fe1a7ba9994876fd6397f8a5df 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -76,7 +76,8 @@ class BFloat16NormalizationTest : public HloTestBase { StatusOr result = normalization.Run(module); EXPECT_IS_OK(result.status()); - HloVerifier verifier(/*allow_mixed_precision=*/true); + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true); EXPECT_IS_OK(verifier.Run(module).status()); return result.ValueOrDie(); @@ -251,8 +252,8 
@@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, - /*replica_group_ids=*/{}, /*barrier=*/"", - /*all_reduce_id=*/tensorflow::gtl::nullopt)); + /*replica_groups=*/{}, /*barrier=*/"", + /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 2fb401c4289728f3f59538464c5b8ad49957985b..545a6ecfb1fca88c2c759e820f9d87a38b1941ca 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -407,7 +407,7 @@ void BFloat16Propagation::AdjustCalledComputationParameters( HloInstruction* hlo) { auto adjust_computation = [this, hlo](HloComputation* computation, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { // Adjust parameters. CHECK_EQ(operands.size(), computation->num_parameters()); for (int64 i = 0; i < operands.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 02b8cad089dd8465b7af5c1014e37b77ded6949d..1ee64971ab53e1775294afde1c779369a838008a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -64,9 +64,7 @@ class BFloat16Propagation : public HloPassInterface { ~BFloat16Propagation() override = default; - tensorflow::StringPiece name() const override { - return "bfloat16-propagation"; - } + absl::string_view name() const override { return "bfloat16-propagation"; } // Runs the pass on the given module. Returns whether the module was changed // (precision reductions were added). 
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index cfd26fc778cbf9b031fb73259eb76538327b2a6c..8b8c6bfd269971efa6fcd186e4825e6f13bb4094 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -22,8 +22,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -36,20 +38,15 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { +namespace { +using absl::StrAppend; +using absl::StrAppendFormat; using ::tensorflow::gtl::FlatMap; using ::tensorflow::gtl::FlatSet; -using ::tensorflow::strings::Appendf; using ::tensorflow::strings::HumanReadableNumBytes; -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; - -namespace { template string ColocatedBufferSetsToString(const T& container, const char* title) { @@ -61,12 +58,65 @@ string ColocatedBufferSetsToString(const T& container, const char* title) { return result; } -// Walk the call graph of the HLO module and place each computation into either -// thread_local_computations or global_computations depending upon whether the -// computation requires thread-local allocations or global allocations. 
The -// elements in thread_local_computations and global_computations are in post -// order (if computation A has an instruction which calls computation B, then A -// will appear after B in the vector). +// Checks that points-to set of 'instruction' is unambiguous and distinct +// (ensured by CopyInsertion), then adds the buffer from the points-to set at +// 'index' to 'colocated_set'. +const LogicalBuffer* AddBufferToColocatedSet( + const HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis, + std::vector* colocated_set) { + // CopyInsertion ensures root points-to set is unambiguous and distinct. + const auto& points_to = points_to_analysis.GetPointsToSet(instruction); + DCHECK(!points_to.IsAmbiguous()); + colocated_set->push_back(points_to.element(index)[0]); + return colocated_set->back(); +} + +// Given the interference map of a graph (the list of interfering node indices +// for each node), perform graph coloring such that interfering nodes are +// assigned to different colors. Returns the assigned color of the nodes, where +// the colors are represented as integer values [0, color_count). +std::vector ColorInterferenceGraph( + const std::vector>& interference_map) { + const int64 node_count = interference_map.size(); + + // Sort the nodes such that we assign nodes with more interference first. This + // relies on the common heuristic of assigning the most constrained node + // first, but it would be good to investigate other ordering heuristics too. + std::vector nodes(node_count); + std::iota(nodes.begin(), nodes.end(), 0); + std::sort(nodes.begin(), nodes.end(), + [&interference_map](const int64 i, const int64 j) { + return interference_map[i].size() > interference_map[j].size(); + }); + + const int64 kColorUnassigned = -1; + std::vector assigned_colors(node_count, kColorUnassigned); + for (int64 node : nodes) { + // Mark the colors that are already assigned to the neighbors. 
+ std::vector available_colors(node_count, true); + for (int64 neighbor : interference_map[node]) { + int64 color = assigned_colors[neighbor]; + if (color != kColorUnassigned) { + available_colors[color] = false; + } + } + + // Find the color that is not yet assigned to the neighbors. + int64 color = kColorUnassigned; + for (color = 0; color < available_colors.size(); ++color) { + if (available_colors[color]) { + break; + } + } + CHECK_NE(color, kColorUnassigned); + assigned_colors[node] = color; + } + return assigned_colors; +} + +} // namespace + Status GatherComputationsByAllocationType( const HloModule* module, std::vector* thread_local_computations, @@ -107,7 +157,7 @@ Status GatherComputationsByAllocationType( return InvalidArgument( "computation %s has conflicting allocation requirements (global " "and thread-local)", - computation->name().c_str()); + computation->name()); } if (is_thread_local) { @@ -130,7 +180,7 @@ Status GatherComputationsByAllocationType( return InvalidArgument( "computation %s cannot contain call/while op because it " "requires thread-local buffer allocations", - computation->name().c_str()); + computation->name()); } worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. @@ -147,9 +197,8 @@ Status GatherComputationsByAllocationType( true)); // Thread local. break; default: - return InternalError( - "Unexpected calling opcode: %s", - HloOpcodeString(instruction->opcode()).c_str()); + return InternalError("Unexpected calling opcode: %s", + HloOpcodeString(instruction->opcode())); } } } @@ -169,65 +218,6 @@ Status GatherComputationsByAllocationType( return Status::OK(); } -// Checks that points-to set of 'instruction' is unambiguous and distinct -// (ensured by CopyInsertion), then adds the buffer from the points-to set at -// 'index' to 'colocated_set'. 
-const LogicalBuffer* AddBufferToColocatedSet( - const HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis, - std::vector* colocated_set) { - // CopyInsertion ensures root points-to set is unambiguous and distinct. - const auto& points_to = points_to_analysis.GetPointsToSet(instruction); - DCHECK(!points_to.IsAmbiguous()); - colocated_set->push_back(points_to.element(index)[0]); - return colocated_set->back(); -} - -// Given the interference map of a graph (the list of interfering node indices -// for each node), perform graph coloring such that interfering nodes are -// assigned to different colors. Returns the assigned color of the nodes, where -// the colors are represented as integer values [0, color_count). -std::vector ColorInterferenceGraph( - const std::vector>& interference_map) { - const int64 node_count = interference_map.size(); - - // Sort the nodes such that we assign nodes with more interference first. This - // relies on the common heuristic of assigning the most constrained node - // first, but it would be good to investigate other ordering heuristics too. - std::vector nodes(node_count); - std::iota(nodes.begin(), nodes.end(), 0); - std::sort(nodes.begin(), nodes.end(), - [&interference_map](const int64 i, const int64 j) { - return interference_map[i].size() > interference_map[j].size(); - }); - - const int64 kColorUnassigned = -1; - std::vector assigned_colors(node_count, kColorUnassigned); - for (int64 node : nodes) { - // Mark the colors that are already assigned to the neighbors. - std::vector available_colors(node_count, true); - for (int64 neighbor : interference_map[node]) { - int64 color = assigned_colors[neighbor]; - if (color != kColorUnassigned) { - available_colors[color] = false; - } - } - - // Find the color that is not yet assigned to the neighbors. 
- int64 color = kColorUnassigned; - for (color = 0; color < available_colors.size(); ++color) { - if (available_colors[color]) { - break; - } - } - CHECK_NE(color, kColorUnassigned); - assigned_colors[node] = color; - } - return assigned_colors; -} - -} // namespace - size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { uint64 h = std::hash()(s.index()); h = tensorflow::Hash64Combine(h, std::hash()(s.offset())); @@ -236,8 +226,8 @@ size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { } string BufferAllocation::Slice::ToString() const { - return tensorflow::strings::StrCat("{index:", index(), ", offset:", offset_, - ", size:", size_, "}"); + return absl::StrCat("{index:", index(), ", offset:", offset_, + ", size:", size_, "}"); } BufferAllocation::Slice BufferAllocation::GetSlice( @@ -298,7 +288,7 @@ BufferAllocationProto BufferAllocation::ToProto() const { string BufferAllocation::ToString() const { string output; - Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size()); + StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size()); if (color().value() != 0) { StrAppend(&output, ", color ", color().value()); } @@ -330,11 +320,10 @@ string BufferAllocation::ToString() const { }); for (const LogicalBuffer* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); - StrAppend(&output, - tensorflow::strings::Printf( - " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(), - offset_size.offset, offset_size.size, - ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); + StrAppend(&output, absl::StrFormat( + " %s [%d,%d]: %s\n", buffer->ToString(), + offset_size.offset, offset_size.size, + ShapeUtil::HumanStringWithLayout(buffer->shape()))); } return output; } @@ -427,7 +416,7 @@ StatusOr BufferAssignment::GetUniqueSlice( return FailedPrecondition( "BufferAllocation::Slice for instruction %s at index %s cannot " "be determined at compile-time.", - 
instruction->name().c_str(), index.ToString().c_str()); + instruction->name(), index.ToString()); } } else { VLOG(3) << "No allocation"; @@ -436,7 +425,7 @@ StatusOr BufferAssignment::GetUniqueSlice( if (result.allocation() == nullptr) { return FailedPrecondition( "BufferAllocation::Slice not assigned for instruction %s at index %s", - instruction->name().c_str(), index.ToString().c_str()); + instruction->name(), index.ToString()); } return result; } @@ -627,7 +616,7 @@ Status BufferAssignment::ComputeSummaryStats() { stats_.total_allocation_bytes += allocation.size(); } - // Only compute total fragmentation if all computations are sequential. + // Only compute total fragmentation if all computations have schedules. SequentialHloOrdering::HloModuleSequence module_sequence; for (const auto& computation : module_->computations()) { const std::vector* sequence = @@ -648,39 +637,38 @@ Status BufferAssignment::ComputeSummaryStats() { string BufferAssignment::Stats::ToString() const { string s; - Appendf(&s, "BufferAssignment stats:\n"); - Appendf(&s, " parameter allocation: %10s\n", - HumanReadableNumBytes(parameter_allocation_bytes).c_str()); - Appendf(&s, " constant allocation: %10s\n", - HumanReadableNumBytes(constant_allocation_bytes).c_str()); - Appendf(&s, " maybe_live_out allocation: %10s\n", - HumanReadableNumBytes(maybe_live_out_allocation_bytes).c_str()); - Appendf(&s, " preallocated temp allocation: %10s\n", - HumanReadableNumBytes(preallocated_temp_allocation_bytes).c_str()); + StrAppendFormat(&s, "BufferAssignment stats:\n"); + StrAppendFormat(&s, " parameter allocation: %10s\n", + HumanReadableNumBytes(parameter_allocation_bytes)); + StrAppendFormat(&s, " constant allocation: %10s\n", + HumanReadableNumBytes(constant_allocation_bytes)); + StrAppendFormat(&s, " maybe_live_out allocation: %10s\n", + HumanReadableNumBytes(maybe_live_out_allocation_bytes)); + StrAppendFormat(&s, " preallocated temp allocation: %10s\n", + 
HumanReadableNumBytes(preallocated_temp_allocation_bytes)); if (preallocated_temp_fragmentation_bytes >= 0) { const double percent = 100. * preallocated_temp_fragmentation_bytes / preallocated_temp_allocation_bytes; - Appendf( + StrAppendFormat( &s, " preallocated temp fragmentation: %10s (%.2f%%)\n", - HumanReadableNumBytes(preallocated_temp_fragmentation_bytes).c_str(), - percent); + HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent); } - Appendf(&s, " total allocation: %10s\n", - HumanReadableNumBytes(total_allocation_bytes).c_str()); + StrAppendFormat(&s, " total allocation: %10s\n", + HumanReadableNumBytes(total_allocation_bytes)); if (total_fragmentation_bytes >= 0) { const double percent = 100. * total_fragmentation_bytes / total_allocation_bytes; - Appendf(&s, " total fragmentation: %10s (%.2f%%)\n", - HumanReadableNumBytes(total_fragmentation_bytes).c_str(), percent); + StrAppendFormat(&s, " total fragmentation: %10s (%.2f%%)\n", + HumanReadableNumBytes(total_fragmentation_bytes), percent); } return s; } string BufferAssignment::ToString() const { string output; - tensorflow::strings::StrAppend(&output, "BufferAssignment:\n"); + absl::StrAppend(&output, "BufferAssignment:\n"); for (auto& allocation : allocations_) { - tensorflow::strings::StrAppend(&output, allocation.ToString()); + absl::StrAppend(&output, allocation.ToString()); } return output; } @@ -1100,8 +1088,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique( - MakeUnique(alignment)), + HeapSimulator::Run(absl::make_unique( + absl::make_unique(alignment)), assignment->module(), module_sequence, assignment->points_to_analysis(), assignment->buffer_size_, options)); @@ -1130,11 +1118,12 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const 
HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique( - MakeUnique(alignment)), - *computation, *instruction_sequence, - assignment->points_to_analysis(), - assignment->buffer_size_, options)); + HeapSimulator::Run( + absl::make_unique( + absl::make_unique(alignment)), + *computation, *instruction_sequence, + assignment->points_to_analysis(), assignment->buffer_size_, + options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1646,7 +1635,8 @@ StatusOr> BufferAssigner::CreateAssignment( XLA_VLOG_LINES(3, liveness->ToString()); XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); - // Can't use MakeUnique because BufferAssignment constructor is private. + // Can't use absl::make_unique because BufferAssignment constructor is + // private. std::unique_ptr assignment( new BufferAssignment(module, std::move(liveness), std::move(buffer_size), std::move(color_alignment))); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 94495290c131e22392079dc2d0237d990b646d3e..24ba7c16f548c10f58f41d2b88488939ca2d8e4d 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" @@ -41,6 +41,17 @@ limitations under the License. 
namespace xla { +// Walk the call graph of the HLO module and place each computation into either +// thread_local_computations or global_computations depending upon whether the +// computation requires thread-local allocations or global allocations. The +// elements in thread_local_computations and global_computations are in post +// order (if computation A has an instruction which calls computation B, then A +// will appear after B in the vector). +Status GatherComputationsByAllocationType( + const HloModule* module, + std::vector* thread_local_computations, + std::vector* global_computations); + // This class abstracts an allocation of contiguous memory which can hold the // values described by LogicalBuffers. Each LogicalBuffer occupies a sub-range // of the allocation, represented by a Slice. A single BufferAllocation may hold diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index eccb146a0d7d628870be179a540d9750df3fe41c..8bd1533972413194dec3609829c8cf8df570cc2a 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" @@ -37,7 +37,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" @@ -79,15 +79,14 @@ const std::vector GetInstructions(HloInstruction* root) { return main_list.GetInstructions(); } -class BufferAssignmentTest : public HloTestBase { +class BufferAssignmentTest : public HloVerifiedTestBase { protected: - BufferAssignmentTest() {} ~BufferAssignmentTest() override {} std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique(module), + module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -98,7 +97,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignmentNoBuffersForConstants( HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique(module), + module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -109,7 +108,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunColoredBufferAssignment( HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique(module), + module, absl::make_unique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -119,7 +118,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr 
RunBufferAssignmentWithInstructionSequence( HloModule* module, - tensorflow::gtl::ArraySlice instruction_sequence, + absl::Span instruction_sequence, int64 alignment = 1) { SequentialHloOrdering::HloModuleSequence module_sequence; module_sequence[module->entry_computation()] = @@ -127,7 +126,8 @@ class BufferAssignmentTest : public HloTestBase { instruction_sequence.end()); return BufferAssigner::Run( module, - xla::MakeUnique(module, module_sequence), + absl::make_unique(module, + module_sequence), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -147,6 +147,17 @@ class BufferAssignmentTest : public HloTestBase { return builder.Build(); } + std::unique_ptr BuildReduceComputation(const string& name) { + auto builder = HloComputation::Builder(name); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + auto param2 = + builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y")); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param2)); + return builder.Build(); + } + // Builds a simple compare-to-limit (x < 4) computation for a While. 
// // condition: @@ -163,8 +174,8 @@ class BufferAssignmentTest : public HloTestBase { HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto index = builder.AddInstruction( HloInstruction::CreateGetTupleElement(const4->shape(), param, 0)); - builder.AddInstruction( - HloInstruction::CreateBinary(r0f32_, HloOpcode::kLt, index, const4)); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, index, const4)); return builder.Build(); } @@ -311,12 +322,12 @@ TEST_F(BufferAssignmentTest, ScalarConstant) { module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); } } @@ -335,13 +346,13 @@ TEST_F(BufferAssignmentTest, BufferForConst) { module->AddEntryComputation(builder.Build()); { - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); EXPECT_TRUE(buffers->HasTopLevelAllocation(const0)); EXPECT_TRUE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); } { - auto buffers = RunBufferAssignmentNoBuffersForConstants(module.get()); + auto buffers = RunBufferAssignmentNoBuffersForConstants(module); EXPECT_FALSE(buffers->HasTopLevelAllocation(const0)); EXPECT_FALSE(buffers->HasTopLevelAllocation(const1)); GetAssignedOutputAllocation(*buffers, add); @@ -363,7 +374,7 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); // Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation() // reports for the instruction directly. 
EXPECT_EQ(buffers->HasTopLevelAllocation(tuple), @@ -386,7 +397,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); // The copy node now has an output buffer. GetAssignedOutputAllocation(*buffers, copy); } @@ -400,12 +411,14 @@ TEST_F(BufferAssignmentTest, Basic) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( @@ -413,7 +426,7 @@ TEST_F(BufferAssignmentTest, Basic) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); // Distinct input buffers were assigned for parameters. 
BufferAllocation paramscalar_buffer = @@ -447,12 +460,14 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( @@ -472,7 +487,7 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module.get(), colorer); + auto buffers = RunColoredBufferAssignment(module, colorer); // Distinct input buffers were assigned for parameters. 
BufferAllocation paramscalar_buffer = @@ -506,12 +521,14 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( @@ -539,7 +556,7 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { return Status::OK(); }; - auto buffers = RunColoredBufferAssignment(module.get(), colorer); + auto buffers = RunColoredBufferAssignment(module, colorer); // Distinct input buffers were assigned for parameters. 
BufferAllocation paramscalar_buffer = @@ -576,12 +593,14 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); auto sub = builder.AddInstruction( @@ -589,7 +608,7 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); // Input buffers were assigned for parameters. BufferAllocation paramscalar_buffer = @@ -640,7 +659,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size"; // Assigns buffers and fetches sizes. - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); int64 size0 = ValidateBuffers(level0, *buffers); int64 size1 = ValidateBuffers(level1, *buffers); @@ -675,10 +694,10 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { // output. (Reuse is not safe in the general case, as it reshapes and some // out-of-order reductions could overwrite an element before a use.) 
// - // param0[100] --- (exp1) --- (exp2) --- (reduce x+1) --- (exp3) + // param0[100] --- (exp1) --- (exp2) --- (reduce x+y) --- (exp3) auto module = CreateNewModule(); auto reduce_computation = - module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1")); + module->AddEmbeddedComputation(BuildReduceComputation("f32+f32")); auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( @@ -699,7 +718,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); const std::vector instrs = GetInstructions(exp3); ValidateBuffers(instrs, *buffers); @@ -755,7 +774,7 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { EXPECT_EQ(8, levelb.size()) << "Invalid nested body size"; // Assigns buffers and fetches sizes. - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); int64 size0 = ValidateBuffers(level0, *buffers); int64 sizec = ValidateBuffers(levelc, *buffers); int64 sizeb = ValidateBuffers(levelb, *buffers); @@ -820,7 +839,7 @@ TEST_F(BufferAssignmentTest, ExampleConditional) { EXPECT_EQ(2, true_instrs.size()); EXPECT_EQ(2, false_instrs.size()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); ValidateBuffers(conditional_instrs, *buffers); ValidateBuffers(true_instrs, *buffers); ValidateBuffers(false_instrs, *buffers); @@ -858,7 +877,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // tanh and exp2 can reuse exp1's buffer EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1)); @@ -887,7 +906,7 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) { auto module = CreateNewModule(); 
module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // negate and broadcast should share a buffer. EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -920,7 +939,7 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -957,7 +976,7 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // The instructions should not share buffers. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -992,7 +1011,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // The broadcast output buffer cannot be shared. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1024,7 +1043,7 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // negate and broadcast should share a buffer. 
EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast)); @@ -1062,7 +1081,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // The broadcast output buffer cannot be shared. EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast), @@ -1106,7 +1125,7 @@ TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) { HloInstruction::CreateMap(vec_shape, {call}, map_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Allocations for the map computation should be thread-local and not // live-out. @@ -1155,7 +1174,7 @@ TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // There should be four allocations: one for vector of pointers, and one for // each tuple element. @@ -1191,7 +1210,7 @@ TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Only some of the elements of the input param are liveout. 
EXPECT_FALSE( @@ -1234,7 +1253,7 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); EXPECT_EQ(3, assignment->Allocations().size()); } @@ -1248,7 +1267,7 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) { /*operands=*/{}, /*custom_call_target=*/"foo_function")); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); EXPECT_EQ(3, assignment->Allocations().size()); EXPECT_TRUE( @@ -1279,7 +1298,7 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) { HloInstruction::CreateCall(tuple_shape, {param}, sub_computation)); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); EXPECT_EQ(2, assignment->Allocations().size()); // Buffers for call are colocated with the sub-computation. @@ -1341,7 +1360,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { module->AddEntryComputation(std::move(a_computation)); module->AddEmbeddedComputation(std::move(b_computation)); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Buffers for call are colocated with the sub-computations. EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}), @@ -1377,7 +1396,7 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Bitcast should get the same allocation as the param. 
EXPECT_EQ(1, assignment->Allocations().size()); @@ -1404,7 +1423,7 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // Select shallow copies one of its operands so it defines its own top-level // buffer and receives its own allocation. @@ -1442,7 +1461,7 @@ TEST_F(BufferAssignmentTest, TupleBufferNotReused) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(module); // There should be no buffer reuse. The copy should not reuse the tuple // buffer. @@ -1476,12 +1495,12 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { auto dot_bc = builder.AddInstruction( HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); builder.AddInstruction( - HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1)); + HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0)); // Run buffer assignment with alignment=1. auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto assignment = RunBufferAssignment(module.get(), /*alignment=*/1); + auto assignment = RunBufferAssignment(module, /*alignment=*/1); // There are 5 allocations: 3 parameters, 1 output, and 1 temp. EXPECT_EQ(5, assignment->Allocations().size()); @@ -1500,7 +1519,7 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { EXPECT_EQ(80, slice_bc.allocation()->size()); // Re-run buffer assignment with alignment=64. 
- assignment = RunBufferAssignment(module.get(), /*alignment=*/64); + assignment = RunBufferAssignment(module, /*alignment=*/64); EXPECT_EQ(5, assignment->Allocations().size()); slice_ab = assignment->GetUniqueTopLevelSlice(dot_ab).ConsumeValueOrDie(); slice_bc = assignment->GetUniqueTopLevelSlice(dot_bc).ConsumeValueOrDie(); @@ -1531,12 +1550,14 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { auto builder = HloComputation::Builder(TestName()); auto paramscalar = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( - f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); + f32vec100_, HloOpcode::kMultiply, broadcast, param0)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1)); builder.AddInstruction(HloInstruction::CreateBinary( @@ -1544,16 +1565,13 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); - // Trivially, the set of peak memory logical buffer(s) of an allocation with a - // single logical buffer should be exactly the logical buffer in that - // allocation. 
const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); const std::vector& peak_buffers = mul_buffer.PeakMemoryLogicalBuffers(); ASSERT_EQ(peak_buffers.size(), 1); - EXPECT_EQ(peak_buffers[0]->instruction(), mul); + EXPECT_EQ(peak_buffers[0]->instruction(), broadcast); } TEST_F(BufferAssignmentTest, PeakBuffers) { @@ -1589,7 +1607,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { module->AddEntryComputation(builder.Build()); auto buffers = RunBufferAssignmentWithInstructionSequence( - module.get(), {param, log, rev, neg, concat, root}); + module, {param, log, rev, neg, concat, root}); // The temporary buffer should hold the 4 interior instructions. const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, concat); @@ -1645,7 +1663,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) { ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0})); module->AddEntryComputation(builder.Build()); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(module); const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast); const std::vector& peak_buffers = buffer.PeakMemoryLogicalBuffers(); @@ -1695,15 +1713,13 @@ ENTRY main { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(hlo_text)); - + ParseAndVerifyModule(hlo_text); HloInstruction* constant_1 = - module->entry_computation()->GetInstructionWithName("constant.1.1"); + module().entry_computation()->GetInstructionWithName("constant.1.1"); HloInstruction* constant_2 = - module->entry_computation()->GetInstructionWithName("constant.1.2"); + module().entry_computation()->GetInstructionWithName("constant.1.2"); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(&module()); { const BufferAllocation& allocation_for_const_1 = @@ -1732,7 +1748,7 @@ ENTRY main { } } -class WhileBufferAssignmentTest : public HloTestBase { +class WhileBufferAssignmentTest : public HloVerifiedTestBase { protected: 
std::unique_ptr BuildWhileConditionComputation( const string& name) { @@ -1769,7 +1785,8 @@ class WhileBufferAssignmentTest : public HloTestBase { auto sequence = ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, xla::MakeUnique(module, sequence), + module, + absl::make_unique(module, sequence), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -1805,9 +1822,9 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { auto zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto output1 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -1831,8 +1848,8 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module.get()); - auto assignment = RunBufferAssignment(module.get()); + RunCopyInsertion(module); + auto assignment = RunBufferAssignment(module); // Verify 'input0' and read-only use while0{0} alias. EXPECT_EQ(assignment->GetUniqueSlice(input0, {}).ConsumeValueOrDie(), @@ -1888,20 +1905,20 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); + ParseAndVerifyModule(module_str); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. 
- int64 instruction_count = module->instruction_count(); + int64 instruction_count = module().instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); - ASSERT_EQ(instruction_count, module->instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(&module()).status()); + ASSERT_EQ(instruction_count, module().instruction_count()); // Get the instructions in the module. - const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* bcast = + module().entry_computation()->root_instruction(); const HloInstruction* param = - module->entry_computation()->parameter_instruction(0); + module().entry_computation()->parameter_instruction(0); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -1909,7 +1926,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(&module()); TF_ASSERT_OK_AND_ASSIGN(auto slice_param, assignment->GetUniqueSlice(param, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -1956,20 +1973,20 @@ ENTRY %test_module { ROOT %bcast = s32[1024,1024]{1,0} broadcast(s32[] %while.1), dimensions={} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); + ParseAndVerifyModule(module_str); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. 
- int64 instruction_count = module->instruction_count(); + int64 instruction_count = module().instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); - ASSERT_EQ(instruction_count, module->instruction_count()); + ASSERT_IS_OK(copy_insertion.Run(&module()).status()); + ASSERT_EQ(instruction_count, module().instruction_count()); // Get the instructions in the module. - const HloInstruction* bcast = module->entry_computation()->root_instruction(); + const HloInstruction* bcast = + module().entry_computation()->root_instruction(); const HloInstruction* constant = - module->entry_computation()->GetInstructionWithName("constant.42"); + module().entry_computation()->GetInstructionWithName("constant.42"); ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); const HloInstruction* while1 = bcast->operand(0); ASSERT_EQ(while1->opcode(), HloOpcode::kWhile); @@ -1977,7 +1994,7 @@ ENTRY %test_module { ASSERT_EQ(while0->opcode(), HloOpcode::kWhile); // Run buffer assignment. - auto assignment = RunBufferAssignment(module.get()); + auto assignment = RunBufferAssignment(&module()); TF_ASSERT_OK_AND_ASSIGN(auto slice_constant, assignment->GetUniqueSlice(constant, {})); TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, @@ -2070,7 +2087,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // any copies inserted for BufferAssignment to run. 
int64 instruction_count = module->instruction_count(); CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + ASSERT_IS_OK(copy_insertion.Run(module).status()); ASSERT_EQ(instruction_count, module->instruction_count()); // Create a sequential order among all the instructions in the entry @@ -2082,8 +2099,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { TF_ASSERT_OK_AND_ASSIGN( auto assignment, BufferAssigner::Run( - module.get(), - xla::MakeUnique(module.get(), sequence), + module, absl::make_unique(module, sequence), backend().compiler()->BufferSizeBytesFunction(), [](LogicalBuffer::Color) { return 1; }, /*allow_input_output_aliasing=*/false, @@ -2120,7 +2136,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { auto zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -2141,8 +2157,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module.get()); - auto assignment = RunBufferAssignment(module.get()); + RunCopyInsertion(module); + auto assignment = RunBufferAssignment(module); // while0 and while1 buffers should be completely aligned. 
EXPECT_EQ(assignment->GetUniqueSlice(while0, {0}).ConsumeValueOrDie(), @@ -2184,13 +2200,13 @@ TEST_F(BufferAssignmentTest, TwoCalls) { { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); } - RunCopyInsertion(module.get()); - auto assignment = RunBufferAssignment(module.get()); + RunCopyInsertion(module); + auto assignment = RunBufferAssignment(module); EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); } @@ -2214,15 +2230,14 @@ ENTRY Main { } )"; - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - HloRunner::CreateModuleFromString( - hlo_text, legacy_flags::GetDebugOptionsFromFlags())); + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + ParseAndVerifyModule(hlo_text, config); - auto buffers = RunBufferAssignment(module.get()); + auto buffers = RunBufferAssignment(&module()); - HloComputation* main = module->entry_computation(); - HloComputation* callee = module->GetComputationWithName("Callee"); + HloComputation* main = module().entry_computation(); + HloComputation* callee = module().GetComputationWithName("Callee"); EXPECT_NE(callee, nullptr); HloInstruction* param0 = callee->parameter_instruction(0); @@ -2282,14 +2297,14 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto weights0 = builder.AddInstruction( HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto input1 = builder.AddInstruction( HloInstruction::CreateParameter(2, data_shape_, "input1")); auto weights1 = builder.AddInstruction( HloInstruction::CreateParameter(3, data_shape_, "weights1")); auto output1 = 
builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, one, {1})); + HloInstruction::CreateBroadcast(data_shape_, one, {})); auto cond = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -2309,18 +2324,18 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { HloInstruction::CreateGetTupleElement(data_shape_, while0, 0)); auto gte1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, while1, 1)); - auto root_add = builder.AddInstruction(HloInstruction::CreateBinary( - while0->shape(), HloOpcode::kAdd, gte0, gte1)); + auto root_add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, gte0, gte1)); module->AddEntryComputation(builder.Build()); { FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module)); EXPECT_TRUE(result); } - RunCopyInsertion(module.get()); + RunCopyInsertion(module); auto sequence = ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); @@ -2339,8 +2354,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto assignment = BufferAssigner::Run( - module.get(), - xla::MakeUnique(module.get(), sequence), + module, absl::make_unique(module, sequence), ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true) @@ -2361,9 +2375,9 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { auto zero = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto output1 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + HloInstruction::CreateBroadcast(data_shape_, zero, {})); auto cond0 = 
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -2394,8 +2408,8 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); module->AddEntryComputation(builder.Build()); - RunCopyInsertion(module.get()); - auto assignment = RunBufferAssignment(module.get()); + RunCopyInsertion(module); + auto assignment = RunBufferAssignment(module); // Get BufferAllocation for root instruction. auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out) .ConsumeValueOrDie() diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 810d597e730c1823668c81598df6138655e58b55..9b2783a214a686f3148723d19bbc94421fc8b4e4 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -28,8 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -75,27 +75,25 @@ Status BufferLiveness::Analyze() { string BufferLiveness::ToString() const { std::vector pieces; - pieces.push_back(tensorflow::strings::Printf("BufferLiveness(module=%s):", - module_->name().c_str())); + pieces.push_back( + absl::StrFormat("BufferLiveness(module=%s):", module_->name())); pieces.push_back("HloOrdering:"); pieces.push_back(hlo_ordering_->ToString()); - pieces.push_back(tensorflow::strings::Printf("Aliased buffers:")); + pieces.push_back("Aliased buffers:"); for (const LogicalBuffer* buffer : aliased_buffers_) { - pieces.push_back( - tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); + pieces.push_back(absl::StrFormat(" %s", buffer->ToString())); } - pieces.push_back(tensorflow::strings::Printf("Live out buffers:")); + pieces.push_back("Live out buffers:"); for (const LogicalBuffer* buffer : maybe_live_out_buffers_) { - pieces.push_back( - tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); + pieces.push_back(absl::StrFormat(" %s", buffer->ToString())); } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, const LogicalBuffer& b) const { - TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a)); - TF_CHECK_OK(points_to_analysis_->VerifyBuffer(b)); + TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(a)); + TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(b)); if (!hlo_ordering_->ExecutesBefore(a.instruction(), b.instruction())) { return false; diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 
4a927b57674345f8b3493c098778182a299c5902..26e26e316d6281a97f8317f8ed1d7a6f21b0d374 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -18,8 +18,9 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -119,8 +120,8 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); @@ -167,10 +168,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), sequence)) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique( + module.get(), sequence)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. 
@@ -215,8 +216,8 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -249,8 +250,8 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -293,10 +294,10 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { SequentialHloOrdering::HloModuleSequence module_sequence; std::vector order = {param, negate, exp, add}; module_sequence.emplace(computation, order); - auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique( + module.get(), module_sequence)) + .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -342,10 +343,10 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { std::vector order = {param, add, recv, recv_done, send, send_done}; module_sequence.emplace(computation, order); - auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique( + module.get(), module_sequence)) + .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); // Check the root instruction (add) buffer interferes with the recv buffer. 
@@ -376,8 +377,8 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // All buffers should be live out except the param @@ -412,8 +413,8 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // Buffers in different computations should always interfere. @@ -453,8 +454,8 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // Only the element buffers of the tuple constant which are pointed to by @@ -518,8 +519,8 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { module->AddEmbeddedComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -580,8 +581,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { module->AddEmbeddedComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -610,11 +611,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { protected: // Builds and runs a 
computation (see test case computation graphs below). - // Runs BufferLiveness on this computation. - // Returns whether buffer interference is detected between tuple-shaped - // parameter and root instructions at tuple element 1. - bool Run(const bool update_uses_tuple_element1, - const bool fuse_gte0 = false) { + std::unique_ptr BuildModule(const bool update_uses_tuple_element1, + const bool fuse_gte0) { auto builder = HloComputation::Builder(TestName()); // Create param0 Tuple. Shape data_shape = ShapeUtil::MakeShape(F32, {8}); @@ -645,12 +643,12 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); // Create output tuple. - auto tuple_root = builder.AddInstruction( + builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. auto module = CreateNewModule(); - module->AddEntryComputation(BuildDummyComputation()); - auto* computation = module->AddEmbeddedComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); + auto* computation = module->entry_computation(); // Create fusion instruction based on number of tuple element 1 users. if (update_uses_tuple_element1) { computation->CreateFusionInstruction( @@ -666,16 +664,39 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { computation->CreateFusionInstruction({gte0}, HloInstruction::FusionKind::kLoop); } + return module; + } + // Returns whether buffer interference is detected between tuple-shaped + // parameter and root instructions at tuple element 1. + bool Run(const bool update_uses_tuple_element1, + const bool fuse_gte0 = false) { + auto module = BuildModule(update_uses_tuple_element1, fuse_gte0); // Run BufferLiveness on 'module'. 
- auto liveness = - BufferLiveness::Run( - module.get(), xla::MakeUnique(module.get())) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run( + module.get(), + absl::make_unique(module.get())) + .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. + auto tuple_param0 = FindInstruction(module.get(), "param0"); + auto tuple_root = module->entry_computation()->root_instruction(); return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); } + bool RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1, + const bool fuse_gte0 = false) { + auto module = BuildModule(update_uses_tuple_element1, fuse_gte0); + // Run BufferLiveness on 'module'. + auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie(); + auto hlo_ordering = absl::make_unique(module.get()); + // Return whether or not buffers interference is detected between + // 'tuple_param0' and 'tuple_root' at shape index '{1}'. 
+ auto tuple_param0 = FindInstruction(module.get(), "param0"); + auto tuple_root = module->entry_computation()->root_instruction(); + return hlo_ordering->MayInterfere( + dataflow->GetUniqueValueAt(tuple_param0, {1}), + dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow); + } }; // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) @@ -693,6 +714,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) { EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false)); + EXPECT_FALSE( + RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false)); } // Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases @@ -712,6 +735,8 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) { // TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) { EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true)); + EXPECT_FALSE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false, + /*fuse_gte0=*/true)); } // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) @@ -736,6 +761,7 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) { // TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) { EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true)); + EXPECT_TRUE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/true)); } class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { @@ -780,10 +806,10 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. 
- auto liveness = - BufferLiveness::Run( - module.get(), xla::MakeUnique(module.get())) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run( + module.get(), + absl::make_unique(module.get())) + .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc index 2bc556a9e270136f5f3eaf2433f8c96eeeaea0a2..fdf822c666b15afbc7553ca89d4f92ab08201869 100644 --- a/tensorflow/compiler/xla/service/buffer_value.cc +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -17,11 +17,10 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h index f4be16e0843f64f41ef27539bf263ae98ce0ebf9..69b36463560a1fad4f62687e9014fb3fbe5bbd13 100644 --- a/tensorflow/compiler/xla/service/buffer_value.h +++ b/tensorflow/compiler/xla/service/buffer_value.h @@ -19,12 +19,12 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/int_type.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 985ff30e80ac9c41c024e4c4d2d0ebb3cff75167..23b2a327096dfdb3c756a4acc5476ec01dcac1b3 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -17,21 +17,21 @@ limitations under the License. #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/types.h" namespace xla { -using ::tensorflow::strings::Appendf; -using ::tensorflow::strings::StrCat; +using absl::StrAppendFormat; +using absl::StrCat; string CallContextToString(CallContext context) { switch (context) { @@ -71,10 +71,10 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { } string CallSite::ToString() const { - return StrCat(instruction()->name(), " calls in context ", - CallContextToString(context()), ": ", - tensorflow::str_util::Join( - called_computations(), ", ", + return StrCat( + 
instruction()->name(), " calls in context ", + CallContextToString(context()), ": ", + absl::StrJoin(called_computations(), ", ", [](string* out, const HloComputation* computation) { out->append(computation->name()); })); @@ -237,8 +237,8 @@ void CallGraph::SetCallContexts() { /* static */ std::unique_ptr CallGraph::Build(const HloModule* module) { - // Constructor for CallGraph is private so MakeUnique can't be used. - auto call_graph = WrapUnique(new CallGraph(module)); + // Constructor for CallGraph is private so absl::make_unique can't be used. + auto call_graph = absl::WrapUnique(new CallGraph(module)); VLOG(2) << "Building call graph for:"; XLA_VLOG_LINES(2, module->ToString()); @@ -356,20 +356,20 @@ CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, string CallGraph::ToString() const { string out; - Appendf(&out, "Call graph for module %s:\n", module_->name().c_str()); + StrAppendFormat(&out, "Call graph for module %s:\n", module_->name()); for (const CallGraphNode& node : nodes()) { - Appendf(&out, "Computation %s:\n", node.computation()->name().c_str()); - Appendf(&out, " calls:\n"); + StrAppendFormat(&out, "Computation %s:\n", node.computation()->name()); + StrAppendFormat(&out, " calls:\n"); for (const HloComputation* callee : node.callees()) { - Appendf(&out, " %s\n", callee->name().c_str()); + StrAppendFormat(&out, " %s\n", callee->name()); } - Appendf(&out, " called by:\n"); + StrAppendFormat(&out, " called by:\n"); for (const HloComputation* caller : node.callers()) { - Appendf(&out, " %s\n", caller->name().c_str()); + StrAppendFormat(&out, " %s\n", caller->name()); } - Appendf(&out, " callsites:\n"); + StrAppendFormat(&out, " callsites:\n"); for (const CallSite& callsite : node.callsites()) { - Appendf(&out, " %s\n", callsite.ToString().c_str()); + StrAppendFormat(&out, " %s\n", callsite.ToString()); } } return out; diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 
97d3811508adee1bf2d0942bcc69e3e34a41c8c3..3af2ab5edfd9faf4ac5193df4b823c21b55b2f7f 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -15,8 +15,8 @@ limitations under the License. // Call graph for an HLO module. -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_ #include @@ -272,4 +272,4 @@ class CallGraph { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_ diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 256d05a73e0bf61d959d21795c106286b52d0b19..1d4214044409ae06239506e610000c839450a030 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -96,7 +96,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { if (it == subcomputation_hlo_to_new_hlo_.end()) { return NotFound( "Could not find mapping from subcomputation HLO %s to a cloned HLO.", - subcomputation_hlo->ToString().c_str()); + subcomputation_hlo->ToString()); } return it->second; } diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index a8345a394d46c90a48305313dac0bcd9b06938ac..c5cd88b9ea2a9c308786d4d7476316b1e592d40a 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_ #include @@ -35,11 +35,11 @@ class CallInliner : public HloPassInterface { static StatusOr Inline(HloInstruction* call); ~CallInliner() override = default; - tensorflow::StringPiece name() const override { return "CallInliner"; } + absl::string_view name() const override { return "CallInliner"; } StatusOr Run(HloModule* module) override; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_ diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index ff968bca297077c7cf869ff8d2becb8bf739dce3..5d85a3f173d50a964420e720f5c9b416731d948c 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -32,7 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace op = xla::testing::opcode_matchers; diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index 13008efed1494402eaff47904c2e4797334381a1..3c2d1ae6d82ebc6c10d52194fd1cec5e291025f7 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/channel_tracker.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -73,20 +73,20 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) { Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { if (opaque_to_channel_.count(handle.handle()) == 0) { - return NotFound("channel handle not found: %lld", handle.handle()); + return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; if (channel.type == ChannelHandle::HOST_TO_DEVICE) { return FailedPrecondition( "host-to-device channels cannot be used with a Send operation; " - "channel handle: %lld", + "channel handle: %d", handle.handle()); } if (channel.has_sender) { return FailedPrecondition( "when registering send, 
passed a channel handle that is already used " - "by a sender: %lld", + "by a sender: %d", handle.handle()); } channel.has_sender = true; @@ -95,13 +95,13 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { if (opaque_to_channel_.count(handle.handle()) == 0) { - return NotFound("channel handle not found: %lld", handle.handle()); + return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; if (channel.type == ChannelHandle::DEVICE_TO_HOST) { return FailedPrecondition( "device-to-host channels cannot be used with a Recv operation; " - "channel handle: %lld", + "channel handle: %d", handle.handle()); } @@ -109,7 +109,7 @@ Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { if (channel.receiver_count >= 1) { return FailedPrecondition( "when registering recv, passed a channel handle that is already used " - "by a receiver: %lld", + "by a receiver: %d", handle.handle()); } channel.receiver_count += 1; diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index d773558c284a7d645f2766bb88c50f7da3777e5d..52037bf9b52556c6aa2e66dd3209e25cf085cfe3 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -18,12 +18,12 @@ limitations under the License. 
#include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 7426672a7a2a9102bd5ea98bd51092982e1e09b4..e5a6c28478a7ebf87878c3937069f15cafe12615 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -28,7 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -62,7 +62,7 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options, StatusOr>> CompileOnlyService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata) { std::vector> hlo_modules; @@ -76,9 +76,9 @@ CompileOnlyService::CompileAheadOfTime( if (!directory_path.empty()) { HloSnapshot hlo_snapshot; *hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation; - string filename = tensorflow::strings::StrCat( - "computation_", instance.computation.id(), "__", - instance.computation.entry_computation_name()); + string filename = + absl::StrCat("computation_", instance.computation.id(), "__", + instance.computation.entry_computation_name()); const string& per_host_path = tensorflow::io::JoinPath( directory_path, tensorflow::port::Hostname()); diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index 1ac950bdd66bd034dfdafa8598ec506221e99c2f..61136a3e11fe15fb74eac257f46292c6cd24ce7d 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -50,12 +50,12 @@ class CompileOnlyService : public Service { // |CompileOnlyClient::CompileAheadOfTime| for additional details. 
StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options); StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const absl::Span computations, const AotCompilationOptions& options, std::unique_ptr* metadata); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 6b3b9820f09803c8a04504e6c35c22de51abf04b..687ecafe0c308ecc22857fae650c6998677f605d 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -101,7 +101,7 @@ Compiler::GetPlatformCompilers() { return NotFound( "could not find registered compiler for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } // And then we invoke the factory, placing the result into the mapping. diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 34f7fe12cac5a4dcd3822865bee903d6eabc25c0..1fdda31c34a17a16f75e1efada542c2c2ea15038 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -34,7 +35,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index cb61f3da39fb8eef69fd81066d87a1da91a62935..af8f7f1027a40703137d6880a9865449c560a47b 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -52,9 +52,8 @@ string ComputationLayout::ToString() const { for (auto& param_layout : parameter_layouts_) { params.push_back(param_layout.ToString()); } - return tensorflow::strings::StrCat("(", - tensorflow::str_util::Join(params, ", "), - ") => ", result_layout_.ToString()); + return absl::StrCat("(", absl::StrJoin(params, ", "), ") => ", + result_layout_.ToString()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index 187ce568cbb6c6666e978b8c8114262313c70ba5..2210a8578ad73efb27dc9c230b142c55228d2af5 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -19,8 +19,9 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -29,12 +30,11 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; namespace xla { @@ -60,8 +60,8 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { "computation_count=%d", proto.replica_count(), proto.computation_count()); } - auto assignment = MakeUnique(proto.replica_count(), - proto.computation_count()); + auto assignment = absl::make_unique( + proto.replica_count(), proto.computation_count()); for (int computation = 0; computation < proto.computation_count(); ++computation) { const auto& computation_device = proto.computation_devices(computation); @@ -132,7 +132,7 @@ StatusOr ComputationPlacer::AssignDevices( return NotFound( "could not find registered computation placer for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } if (it->second.placer == nullptr) { @@ -156,7 +156,7 @@ ComputationPlacer::GetPlatformComputationPlacers() { } // namespace xla static std::unique_ptr CreateComputationPlacer() { - return xla::MakeUnique(); + return absl::make_unique(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index 
b7be3ba605a89a736b032eaab5a5085ac64fc549..4ea3a13f2835c5fef99c274f14d7d683c9ff5fc8 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -28,8 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h index 063261e26d06e21a297e8e3c405898a17221b7ca..3de50cbd7ff752e8722a103b68f75144c6c889cd 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.h +++ b/tensorflow/compiler/xla/service/conditional_simplifier.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -27,9 +27,7 @@ namespace xla { // with their true or false computation as appropriate. 
class ConditionalSimplifier : public HloPassInterface { public: - tensorflow::StringPiece name() const override { - return "simplify-conditional"; - } + absl::string_view name() const override { return "simplify-conditional"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc new file mode 100644 index 0000000000000000000000000000000000000000..9c81a86bbb9dc7078237fe200f510a4905cb4d8d --- /dev/null +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -0,0 +1,249 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +// ConvolutionVisitor traverses the HLO computation and rewrites Convolution +// operations with feature_group_count > 1 into convolutions with +// feature_group_count = 1. +class ConvolutionVisitor : public DfsHloVisitorWithDefault { + public: + // Default visitor action is to do nothing and return OK. + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleConvolution(HloInstruction* convolution) override; + + // Runs the visitor on a computation. + static bool Run(HloComputation* computation); + + // Returns whether any convolution ops were rewritten. + const bool changed() const { return changed_; } + + ~ConvolutionVisitor() override = default; + + private: + explicit ConvolutionVisitor(HloComputation* computation) + : computation_(computation) {} + + // Current HloComputation instance the ConvolutionVisitor is traversing. + HloComputation* computation_; + + // Whether rewrite has occurred. 
+ bool changed_ = false; +}; + +bool ConvolutionVisitor::Run(HloComputation* computation) { + ConvolutionVisitor visitor(computation); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; +} + +Shape ExpandedFilterShape(const Shape& shape, int64 group_count, + int64 input_feature_dim) { + int64 num_dims = shape.dimensions_size(); + CHECK_GE(num_dims, 2); + Shape expanded_shape = shape; + expanded_shape.set_dimensions( + input_feature_dim, shape.dimensions(input_feature_dim) * group_count); + return expanded_shape; +} + +// Returns a vector with 'group_count' many groups, where the i-th group +// consists of 'group_size' times the value i. +std::vector GetMaskIds(int64 group_size, int64 group_count) { + std::vector values; + for (int i = 0; i < group_count; ++i) { + for (int j = 0; j < group_size; ++j) { + values.push_back(i); + } + } + return values; +} + +// Create a mask for grouped convolution that will make a normal convolution +// produce the same results as a grouped convolution. For a [2, 1, 6] +// filter this returns a [2, 3, 6] mask +// 1 1 0 0 0 0 +// 0 0 1 1 0 0 +// 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 +// 0 0 1 1 0 0 +// 0 0 0 0 1 1 +// +// The first step is to create a rank 1 constant: +// 0 1 2 +// +// This is broadcasted to +// 0 0 0 0 0 0 +// 1 1 1 1 1 1 +// 2 2 2 2 2 2 +// +// 0 0 0 0 0 0 +// 1 1 1 1 1 1 +// 2 2 2 2 2 2 +// +// Then we create another rank 1 constant +// 0 0 1 1 2 2 +// +// This is broadcasted to +// 0 0 1 1 2 2 +// 0 0 1 1 2 2 +// 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 +// 0 0 1 1 2 2 +// 0 0 1 1 2 2 +// +// Finally we use the Eq op of these two broadcasted constants and get the +// desired mask. 
+HloInstruction* GetExpandedFilterMask( + const Shape& filter_shape, int64 input_feature_dim, + int64 output_feature_dim, int64 group_count, + const std::function)>& + add_instruction) { + Shape expanded_filter_shape = + ExpandedFilterShape(filter_shape, group_count, input_feature_dim); + Shape mask_shape = ShapeUtil::MakeShape( + S32, AsInt64Slice(expanded_filter_shape.dimensions())); + int64 output_feature = filter_shape.dimensions(output_feature_dim); + int64 group_size = filter_shape.dimensions(input_feature_dim); + + // Create a 'input_feature' sized linspace and 'output_feature' sized linspace + // that will be broadcasted into perpendicular dimensions and compared. + const std::vector input_feature_filter_mask = + GetMaskIds(group_size, group_count); + const std::vector output_feature_filter_mask = + GetMaskIds(output_feature / group_count, group_count); + + auto mask1 = add_instruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(input_feature_filter_mask))); + auto broadcasted_mask1 = add_instruction( + HloInstruction::CreateBroadcast(mask_shape, mask1, {input_feature_dim})); + auto mask2 = add_instruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(output_feature_filter_mask))); + auto broadcasted_mask2 = add_instruction( + HloInstruction::CreateBroadcast(mask_shape, mask2, {output_feature_dim})); + + // Compare the broadcasted output feature linspace to the input feature + // linspace to create a diagonal predicate. 
+ Shape predicate_shape = ShapeUtil::MakeShape( + PRED, AsInt64Slice(expanded_filter_shape.dimensions())); + return add_instruction(HloInstruction::CreateBinary( + predicate_shape, HloOpcode::kEq, broadcasted_mask1, broadcasted_mask2)); +} + +Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { + int64 group_count = convolution->feature_group_count(); + if (group_count == 1) { + return Status::OK(); + } + auto filter = convolution->mutable_operand(1); + changed_ = true; + auto add = [&](std::unique_ptr inst) { + return computation_->AddInstruction(std::move(inst)); + }; + + auto dim_numbers = convolution->convolution_dimension_numbers(); + int64 input_feature_dim = dim_numbers.kernel_input_feature_dimension(); + int64 group_size = filter->shape().dimensions(input_feature_dim); + int64 output_feature_dim = dim_numbers.kernel_output_feature_dimension(); + auto expanded_filter_shape = + ExpandedFilterShape(filter->shape(), group_count, input_feature_dim); + HloInstruction* filter_mask = GetExpandedFilterMask( + filter->shape(), input_feature_dim, output_feature_dim, group_count, add); + HloInstruction* expanded_filter; + // We want to repeat 'filter' in the 'input_feature_dim' dimension + // 'group_count' times. + if (group_size == 1) { + Shape reshaped_filter_shape = + ShapeUtil::DeleteDimension(input_feature_dim, filter->shape()); + auto reshaped_filter = + add(HloInstruction::CreateReshape(reshaped_filter_shape, filter)); + std::vector broadcast_dims; + for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) { + if (i == input_feature_dim) { + continue; + } + broadcast_dims.push_back(i); + } + expanded_filter = add(HloInstruction::CreateBroadcast( + expanded_filter_shape, reshaped_filter, broadcast_dims)); + } else { + // We could possibly also use reshape, broadcast, reshape instead of concat + // here, but it would require more complex code, and for depthwise + // convolution we would never end up in this branch. 
+ std::vector concat_operands(group_count, filter); + expanded_filter = add(HloInstruction::CreateConcatenate( + expanded_filter_shape, concat_operands, input_feature_dim)); + } + auto zero = add(HloInstruction::CreateConstant(absl::make_unique( + LiteralUtil::Zero(expanded_filter_shape.element_type())))); + auto zero_filter = + add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); + auto new_filter = add( + HloInstruction::CreateTernary(expanded_filter_shape, HloOpcode::kSelect, + filter_mask, expanded_filter, zero_filter)); + auto new_convolution = HloInstruction::CreateConvolve( + convolution->shape(), convolution->mutable_operand(0), new_filter, + convolution->window(), dim_numbers, /*feature_group_count=*/1); + new_convolution->set_precision_config(convolution->precision_config()); + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(new_convolution))); + return Status::OK(); +} + +} // namespace + +StatusOr ConvolutionFeatureGroupConverter::Run(HloModule* module) { + XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), before:\n" + + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + if (ConvolutionVisitor::Run(comp)) { + changed = true; + } + } + XLA_VLOG_LINES(2, "ConvolutionFeatureGroupConverter::Run(), after:\n" + + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h new file mode 100644 index 0000000000000000000000000000000000000000..498894737fa37a6d8cca6ead2a86c72eb84ababd --- /dev/null +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +// A pass which rewrites convolutions with feature_group_count > 1 into +// convolutions with feature_group_count = 1. +class ConvolutionFeatureGroupConverter : public HloPassInterface { + public: + ConvolutionFeatureGroupConverter() {} + + absl::string_view name() const override { + return "convolution-feature-group-converter"; + } + + // Run convolution rewriting on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..28373ebf636c7b6b3059dcf6cd931901ebc87fc2 --- /dev/null +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using ConvolutionFeatureGroupConverterTest = HloTestBase; +namespace op = testing::opcode_matchers; + +TEST_F(ConvolutionFeatureGroupConverterTest, + ConvertFeatureGroupCountEqualToInputFeatureDim) { + string hlo_string = R"(HloModule Convolve1D1Window_0_module + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2,2] { + %input = f32[1,2,2]{2,1,0} parameter(0) + %copy = f32[1,2,2]{2,0,1} copy(f32[1,2,2]{2,1,0} %input) + %filter = f32[1,1,2]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,2]{2,0,1} %copy, f32[1,1,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2 +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + auto computation = 
module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + ConvolutionFeatureGroupConverter converter; + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + // Make sure the convolution is converted to one with feature_group_count = 1. + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->feature_group_count(), 1); + // Verify that the filter operand has been replaced. + EXPECT_THAT(root->operand(1), + op::Select(op::Eq(op::Broadcast(op::Constant()), + op::Broadcast(op::Constant())), + op::Broadcast(op::Reshape(op::Parameter())), + op::Broadcast(op::Constant()))); +} + +TEST_F(ConvolutionFeatureGroupConverterTest, + ConvertFeatureGroupCountDivisorOfInputFeatureDim) { + string hlo_string = R"(HloModule Convolve1D1Window_0_module + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2,2] { + %input = f32[1,2,4]{2,1,0} parameter(0) + %copy = f32[1,2,4]{2,0,1} copy(f32[1,2,4]{2,1,0} %input) + %filter = f32[1,2,2]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,2]{2,0,1} convolution(f32[1,2,4]{2,0,1} %copy, f32[1,2,2]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=2 +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + ConvolutionFeatureGroupConverter converter; + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + // Make sure the convolution is converted to one with feature_group_count = 1. + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + EXPECT_EQ(root->feature_group_count(), 1); + // Verify that the filter operand has been replaced. 
+ EXPECT_THAT(root->operand(1), + op::Select(op::Eq(op::Broadcast(op::Constant()), + op::Broadcast(op::Constant())), + // We expect to see Concatenate here instead of + // Broadcast, because feature_group_count < input + // feature dimension. + op::Concatenate(op::Parameter(), op::Parameter()), + op::Broadcast(op::Constant()))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 3e39c1bab1e07d192a8c145be5103085fd3c189b..b65dfef9c9575b683b2656af2ccc151d87db2cd7 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -31,18 +33,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { - -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace { +using absl::StrAppend; + bool IsEntryParameterValue(const HloValue& value) { const HloComputation* computation = value.defining_instruction()->parent(); return value.defining_instruction()->opcode() == HloOpcode::kParameter && @@ -381,7 +378,7 @@ class CopyRemover { } string ToString() const { - string out = StrCat("CopyRemover, module ", module_->name(), "\n"); + string out = absl::StrCat("CopyRemover, module ", module_->name(), "\n"); StrAppend(&out, " Buffer values, in dependency order:\n"); for (const HloBuffer& buffer : alias_analysis_.buffers()) { StrAppend(&out, " HloBuffer ", buffer.id(), ":\n"); @@ -482,7 +479,7 @@ class CopyRemover { // 'values' an entry is created in value_to_node which indicates the // respective ValueNode representing that value. 
void AddValueList( - tensorflow::gtl::ArraySlice values, + absl::Span values, tensorflow::gtl::FlatMap* value_to_node) { ValueNode* tail = nullptr; ValueNode* head = nullptr; @@ -863,16 +860,16 @@ class CopyRemover { for (const ValueNode* p = head; p != nullptr; p = Next(*p)) { values.push_back(p->value); } - return StrCat("{", - Join(values, ", ", - [](string* s, const HloValue* value) { - StrAppend(s, value->ToShortString()); - }), - "}"); + return absl::StrCat("{", + absl::StrJoin(values, ", ", + [](string* s, const HloValue* value) { + StrAppend(s, value->ToShortString()); + }), + "}"); } string ToString() const { - string out = StrCat("BufferValueTracker:\n"); + string out = absl::StrCat("BufferValueTracker:\n"); StrAppend(&out, " Def-use chains in each buffer:\n"); for (const ValueNode* head : value_lists_) { StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), @@ -880,10 +877,10 @@ class CopyRemover { const ValueNode* p = head; do { StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", - Join(p->uses, "; ", - [](string* s, const HloUse* use) { - StrAppend(s, use->ToString()); - }), + absl::StrJoin(p->uses, "; ", + [](string* s, const HloUse* use) { + StrAppend(s, use->ToString()); + }), "\n"); p = p->next; @@ -960,16 +957,11 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { return Status::OK(); } -// Add copies to address special constraints on the roots of computations not -// related to live range interference: -// -// (1) Entry computation root must be unambiguous and distinct. -// -// (2) Any computation called by a kCall instruction must have an -// unambiguous root. 
-// -// (3) Constants and parameters cannot be live out of the entry computation -// +Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) { + std::unique_ptr call_graph = CallGraph::Build(module); + return AddSpecialCaseCopies(*call_graph, module); +} + Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, @@ -1065,15 +1057,6 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, for (HloInstruction* user : users) { TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); } - // Special case copies are not eligible for later copy elision passes. - indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) { - if (has_copy) { - HloInstruction* copy = *copies_added.mutable_element(index); - if (copy != nullptr) { - copy->SetCopyElisionAllowed(false); - } - } - }); if (instruction == instruction->parent()->root_instruction()) { instruction->parent()->set_root_instruction(deep_copy); } @@ -1081,10 +1064,10 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, return Status::OK(); } -Status CopyInsertion::VerifyNoLiveRangeInterference(HloModule* module) { +Status CopyInsertion::VerifyNoLiveRangeInterference(const HloOrdering& ordering, + HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); - DependencyHloOrdering ordering(module); TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering)); return Status::OK(); } @@ -1101,8 +1084,7 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, std::unique_ptr call_graph = CallGraph::Build(module); for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy && - instruction->CopyElisionAllowed()) { + if (instruction->opcode() == HloOpcode::kCopy) 
{ TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); } } @@ -1168,10 +1150,10 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); - TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + DependencyHloOrdering dep_ordering(module); + TF_DCHECK_OK(VerifyNoLiveRangeInterference(dep_ordering, module)); - DependencyHloOrdering ordering(module); - TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module)); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, module)); TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); @@ -1179,7 +1161,8 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); - TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + TF_DCHECK_OK( + VerifyNoLiveRangeInterference(DependencyHloOrdering(module), module)); MaybeDumpModule("after copy insertion", *module); diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 5ba64b78a3c9aff5f323691df2ece9b5e6bf3232..d308f6bc84670b78b9cab476f2893bce267df2cf 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -45,7 +45,7 @@ namespace xla { // InstructionAliasSet::IsDistinct return true. class CopyInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } // fusion_can_share_buffer: backend specific function that decides whether a // fusion can share buffer with its operand. 
@@ -77,15 +77,29 @@ class CopyInsertion : public HloPassInterface { Status RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module); - private: - // Verifies that no HLO values have interfering live ranged assuming the - // ordering used by copy insertion. - Status VerifyNoLiveRangeInterference(HloModule* module); + // Add copies to address special constraints on the roots of computations not + // related to live range interference: + // + // (1) Entry computation root must be unambiguous and distinct. + // + // (2) Any computation called by a kCall instruction must have an + // unambiguous root. + // + // (3) Constants and parameters cannot be live out of the entry computation + // + Status AddSpecialCaseCopies(HloModule* module); - Status AddCopiesToResolveInterference(HloModule* module); + // Verifies that no HLO values have interfering live ranges using the given + // ordering. + Status VerifyNoLiveRangeInterference(const HloOrdering& ordering, + HloModule* module); + private: + // Overload which requires the caller to pass in a call graph. Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module); + Status AddCopiesToResolveInterference(HloModule* module); + // Backend specific function that decides whether a fusion can share buffer // with its operand.
HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_; diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 84779c60b0c790cd0be98f8b42996acf7e3cfd8b..d412578619e5d23db3933af19d665cf8beb4d622 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -50,6 +50,8 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -62,6 +64,7 @@ cc_library( "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -85,6 +88,10 @@ cc_library( ":ir_emitter", ":parallel_task_assignment", ":simple_orc_jit", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ":target_machine_features", + "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", @@ -101,6 +108,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", + "//tensorflow/compiler/xla/service:convolution_feature_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", @@ -177,6 +185,7 @@ cc_library( ":runtime_single_threaded_conv2d", ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", + "@com_google_absl//absl/memory", "@llvm//:execution_engine", "@llvm//:core", "@llvm//:mc", # fixdeps: keep @@ -228,6 +237,9 @@ cc_library( 
"//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@llvm//:orc_jit", ], ) @@ -270,11 +282,15 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@llvm//:code_gen", "@llvm//:core", "@llvm//:support", @@ -319,6 +335,8 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -329,12 +347,12 @@ cc_library( hdrs = ["parallel_loop_emitter.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:core", ], ) @@ -361,6 +379,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -381,6 +400,7 @@ tf_cc_binary( 
"//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", ], ) @@ -394,6 +414,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:mc", "@llvm//:mc_disassembler", "@llvm//:object", @@ -417,6 +438,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", "@llvm//:analysis", "@llvm//:core", "@llvm//:ipo", @@ -445,6 +467,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -633,6 +656,8 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -647,6 +672,8 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -741,6 +768,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -809,6 +837,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -845,6 +875,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -892,6 +923,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc index 408fe0f5bf5d729165eadd532d4740211620645d..1942ea1a2af8a349de53bafe80977436f9740fc4 100644 --- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc @@ -40,7 +40,7 @@ std::vector CreateBufferInfosFromBufferAssignment( } std::vector CreateArgIndexTableFromBufferInfos( - tensorflow::gtl::ArraySlice buffer_infos) { + absl::Span buffer_infos) { std::vector result; for (int64 i = 0; i < buffer_infos.size(); i++) { if (buffer_infos[i].is_entry_parameter()) { diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h index 05de70c72686dcbdaf0b47c46cde23ed45abdb42..e9ee928ab290f2f5338bd7b3804dc43033e2042f 100644 --- a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h +++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -34,7 +34,7 @@ CreateBufferInfosFromBufferAssignment( // If this function returns V then entry parameter i has buffer allocation index // V[i]. 
std::vector CreateArgIndexTableFromBufferInfos( - tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo> + absl::Span buffer_infos); } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 128eea4828b5e514b2ba6b398898e4a5d228e746..73b03440cbb936017257b8a92f16dcc25d41e21c 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -35,7 +36,6 @@ limitations under the License. #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -205,7 +205,7 @@ void CompilerFunctor::AddTargetInfoPasses( llvm::legacy::PassManagerBase* passes) const { llvm::Triple target_triple(target_machine_->getTargetTriple()); auto target_library_info_impl = - MakeUnique(target_triple); + absl::make_unique(target_triple); target_library_info_impl->addVectorizableFunctions( VectorFunctionsForTargetLibraryInfoImpl()); passes->add( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 0985b9297fe487f3523826cb0978c17775549735..098ce17a568fd3fb531020e7731100fabda43721 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -132,6 +132,7 @@ StatusOr ConvCanonicalization::Run(HloModule* 
module) { HloInstruction* new_conv = module->entry_computation()->AddInstruction( HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, hlo->window(), new_dnums)); + new_conv->set_precision_config(hlo->precision_config()); // Reshape the output back to the shape of the original convolution. TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index e6fd1499edd0095395194200a5b444ad61e7e39d..59437e88af27528654a0af86baf69ec7a1e91d60 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -38,7 +38,7 @@ class ConvCanonicalization : public HloPassInterface { : target_machine_features_(*target_machine_features) {} ~ConvCanonicalization() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "convolution-canonicalization"; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 35154af0482a2ea85f17add48b16de4a8bb6affb..796f36510e414cde692208cfe0cf9626acae63d3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -26,6 +26,8 @@ limitations under the License. // IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc" // IWYU pragma: no_include "llvm/Config/Targets.def.inc" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" @@ -42,7 +44,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/batch_dot_simplification.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" @@ -50,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" +#include "tensorflow/compiler/xla/service/convolution_feature_group_converter.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" @@ -100,8 +102,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace cpu { @@ -234,15 +234,15 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { std::unordered_map* hlo_to_profile_idx_; const std::unordered_map& assigned_indices_; }; -} // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, - llvm::TargetMachine* target_machine) { - LLVMTargetMachineFeatures target_machine_features(target_machine); +} // namespace - // Optimization pipeline. 
- HloPassPipeline pipeline("CPU"); - pipeline.AddInvariantChecker(); +Status CpuCompiler::RunHloPassesThroughLayoutAssn( + HloModule* module, bool /*is_aot_compile*/, + LLVMTargetMachineFeatures* target_machine_features) { + HloPassPipeline pipeline("HLO passes through layout assignment"); + pipeline.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( @@ -258,11 +258,13 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(&target_machine_features); + pipeline.AddPass(); + pipeline.AddPass(target_machine_features); { auto& pass = pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(); + pass.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pass.AddPass( /*rewrite_training_op=*/true, @@ -276,7 +278,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. - pipeline.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -289,10 +291,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, } pipeline.AddPass(); pipeline.AddPass( - [&target_machine_features]( - const HloInstruction& dot, + [&](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot, target_machine_features) + return PotentiallyImplementedAsEigenDot(dot, *target_machine_features) ? 
candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -307,12 +308,28 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->mutable_entry_computation_layout(), &target_machine_features); + module->mutable_entry_computation_layout(), target_machine_features); + return pipeline.Run(module).status(); +} + +Status CpuCompiler::RunHloPassesAfterLayoutAssn( + HloModule* module, bool is_aot_compile, + LLVMTargetMachineFeatures* target_machine_features) { + HloPassPipeline pipeline("HLO passes after layout assignment"); + // After layout assignment, use a layout-sensitive verifier. + auto& after_layout_assn = + pipeline.AddPass("after layout assignment"); + after_layout_assn.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); + // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. { auto& pass = pipeline.AddPass>( - "after layout assignement"); + "simplification after layout assignement"); + pass.AddInvariantChecker(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); pass.AddPass>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, @@ -320,7 +337,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pass.AddPass(); pass.AddPass(/*is_layout_sensitive=*/true); } + pipeline.AddPass(BF16, F32); + // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = module->config().intra_op_parallelism_threads() > 0 @@ -333,14 +352,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, // binary size (and most AOT applications are single-threaded). // TODO(b/29630486) Support multi-threaded AOT. 
pipeline.AddPass( - max_parallelism, ShapeSizeBytesFunction(), &target_machine_features); + max_parallelism, ShapeSizeBytesFunction(), target_machine_features); } - // Copy insertion should be performed immediately before IR emission to avoid - // inserting unnecessary copies (later pass adds an instruction which - // materializes the value) or missing a necessary copy (later pass removes an - // instruction which materializes a value). DCE must be run immediately before - // (and sometime after) copy insertion, to avoid dead code from interfering - // with the rewrites. + // Copy insertion should be performed immediately before IR emission to + // avoid inserting unnecessary copies (later pass adds an instruction which + // materializes the value) or missing a necessary copy (later pass removes + // an instruction which materializes a value). DCE must be run immediately + // before (and sometime after) copy insertion, to avoid dead code from + // interfering with the rewrites. pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -348,6 +367,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, return pipeline.Run(module).status(); } +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine) { + LLVMTargetMachineFeatures target_machine_features(target_machine); + TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile, + &target_machine_features)); + return RunHloPassesAfterLayoutAssn(module, is_aot_compile, + &target_machine_features); +} + namespace { // Align buffers to 16-byte boundaries. 
@@ -451,7 +479,7 @@ Status CreateHloProfilingArtifacts( computation_to_profile_idx, std::unique_ptr* hlo_profile_index_map, std::unique_ptr* hlo_profile_printer_data) { - *hlo_profile_index_map = MakeUnique(module); + *hlo_profile_index_map = absl::make_unique(module); const HloComputation& entry_computation = *module.entry_computation(); TF_ASSIGN_OR_RETURN( @@ -518,11 +546,11 @@ StatusOr> CpuCompiler::RunBackend( &pre_optimization_ir_hook, &post_optimization_ir_hook)); // Compile must be thread-safe so create a new LLVM context for the module. - auto llvm_context = xla::MakeUnique(); + auto llvm_context = absl::make_unique(); auto llvm_module = - xla::MakeUnique("__compute_module", *llvm_context); + absl::make_unique("__compute_module", *llvm_context); - auto jit = xla::MakeUnique( + auto jit = absl::make_unique( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), @@ -560,16 +588,15 @@ StatusOr> CpuCompiler::RunBackend( ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); - // Run buffer analysis on the HLO graph. This analysis figures out which - // temporary buffers are required to run the computation. + // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module.get(), - xla::MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module.get(), + absl::make_unique( + module.get(), module_sequence), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. 
XLA_VLOG_LINES(2, assignment->ToString()); @@ -677,8 +704,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, const llvm::Target* target = llvm::TargetRegistry::lookupTarget(triple.getTriple(), error); if (target == nullptr) { - return InternalError("TargetRegistry::lookupTarget failed: %s", - error.c_str()); + return InternalError("TargetRegistry::lookupTarget failed: %s", error); } llvm::Reloc::Model reloc_model = llvm::Reloc::Static; @@ -714,7 +740,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name()); llvm::StringRef features = llvm_ir::AsStringRef(options.features()); llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config()); - std::unique_ptr target_machine = WrapUnique( + std::unique_ptr target_machine = absl::WrapUnique( target->createTargetMachine(triple.getTriple(), cpu_name, features, CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None, opt_level)); @@ -755,7 +781,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, std::unique_ptr assignment, BufferAssigner::Run( module, - xla::MakeUnique(module, module_sequence), + absl::make_unique(module, module_sequence), BufferSizeBytesFunction(), memory_alignment, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); @@ -849,7 +875,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment->GetUniqueTopLevelOutputSlice()); - results.emplace_back(MakeUnique( + results.emplace_back(absl::make_unique( std::move(object_file_data), std::move(buffer_infos), result_slice.index(), std::move(hlo_profile_printer_data))); } @@ -872,7 +898,7 @@ HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::host::kHostPlatformId, - []() { return xla::MakeUnique(); }); + []() { return absl::make_unique(); }); 
return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 04e1c48872ed55ca7f2aa3bec08c44a1666b90f1..f2af923782df268e3e6da3895ec35579ab6aa51f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -18,13 +18,14 @@ limitations under the License. #include +#include "absl/types/span.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -157,6 +158,16 @@ class CpuCompiler : public LLVMCompiler { Status RunHloPasses(HloModule* module, bool is_aot_compile, llvm::TargetMachine* target_machine); + // Runs HLO passes up to and including layout assignment. + Status RunHloPassesThroughLayoutAssn( + HloModule* module, bool /*is_aot_compile*/, + LLVMTargetMachineFeatures* target_machine_features); + + // Runs HLO passes after layout assignment. 
+ Status RunHloPassesAfterLayoutAssn( + HloModule* module, bool is_aot_compile, + LLVMTargetMachineFeatures* target_machine_features); + TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h index 3313d1e6eb71bff39f509c3d24858568df786422..d49f7d7cc2d9b1d00847feda62fa62dd740820d8 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -32,11 +32,11 @@ namespace xla { // (module-scoped). class CpuCopyInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } StatusOr Run(HloModule* module) override; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index c376864c3e1f882e11bc05f8cf93f2fb1c88e4ec..29abf38e439d919ff93629ed992cb3ff93a929bd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -22,6 +22,9 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -35,9 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" @@ -75,9 +75,9 @@ CpuExecutable::CpuExecutable( StatusOr, std::vector>> -CpuExecutable::CreateTempArray( +CpuExecutable::CreateBufferTable( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::vector unowning_buffers( assignment_->Allocations().size()); std::vector owning_buffers( @@ -136,19 +136,19 @@ CpuExecutable::CreateTempArray( Status CpuExecutable::ExecuteComputeFunction( const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice buffers, + absl::Span buffers, HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: // // void function(void* result, const void* run_options, void** args_array, - // void** temps_array) + // void** buffer_table) // // result: Points at the result. // run_options: the ExecutableRunOptions object. // args_array: null - // temps_array: An array of pointers, containing pointers to temporary buffers - // required by the executable adn pointers to entry computation - // parameters. + // buffer_table: An array of pointers, containing pointers to temporary + // buffers required by the executable and pointers to entry computation + // parameters.
// uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -171,20 +171,19 @@ Status CpuExecutable::ExecuteComputeFunction( void* result_buffer = buffer_pointers[result_slice.index()]; if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; - VLOG(3) << tensorflow::strings::Printf( - " func(void* result, void* params[null], void* temps[%zu], " - "uint64 profile_counters[%zu])", + VLOG(3) << absl::StrFormat( + " func(void* result, void* params[null], void* buffer_table[%u], " + "uint64 profile_counters[%u])", buffer_pointers.size(), profile_counters_size); - VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); + VLOG(3) << absl::StrFormat(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { - tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); + absl::StrAppend(out, absl::StrFormat("%p", p)); }; VLOG(3) << " params = nullptr"; - VLOG(3) << tensorflow::strings::Printf( - " temps = [%s]", - tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); - VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", - profile_counters); + VLOG(3) << absl::StrFormat( + " buffer_table = [%s]", + absl::StrJoin(buffer_pointers, ", ", ptr_printer)); + VLOG(3) << absl::StrFormat(" profile_counters = %p", profile_counters); } compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(), @@ -209,7 +208,7 @@ Status CpuExecutable::ExecuteComputeFunction( StatusOr CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::MutableArraySlice buffers) { + absl::Span buffers) { se::Stream* stream = run_options->stream(); ScopedShapedBuffer result_buffer( /*on_host_shape=*/result_shape(), @@ -247,7 +246,7 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( StatusOr CpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, 
HloExecutionProfile* hlo_execution_profile) { TF_ASSIGN_OR_RETURN( auto result, @@ -258,7 +257,7 @@ StatusOr CpuExecutable::ExecuteOnStream( StatusOr CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { if (hlo_profiling_enabled()) { return Unimplemented( "Asynchronous execution on stream with hlo profiling is not yet " @@ -269,7 +268,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { if (GetRootPointsToSet().IsAmbiguous()) { return Unimplemented("Points-to set of root instruction is ambiguous"); @@ -283,11 +282,12 @@ StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( std::vector unowning_buffers; TF_ASSIGN_OR_RETURN( std::tie(unowning_buffers, owning_buffers), - CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), - arguments)); + CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(), + arguments)); - TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, - CreateResultShapedBuffer(run_options, &owning_buffers)); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer result, + CreateResultShapedBuffer(run_options, absl::MakeSpan(owning_buffers))); // At this point, `unowning_buffers` contains unowning pointers to all of our // buffers, and `buffers` contains owning pointers to the non-live-out @@ -300,7 +300,7 @@ StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( // // We also need to change the types of some of the variables we capture: // run_options needs to change from a pointer to a value type, and arguments - // needs to change from an ArraySlice into a vector. We use a struct instead + // needs to change from a Span into a vector. We use a struct instead // of a lambda to make this explicit. 
struct AsyncRunTask { CpuExecutable* executable; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 96e53de57eee013fe6f847c10e23a38f5beb9adc..3c3c047bfe8ee0d1ad90ede2432a86264f47870b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -57,12 +57,12 @@ class CpuExecutable : public Executable { StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override; StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; + absl::Span arguments) override; // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -74,9 +74,10 @@ class CpuExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); // Type of the computation function we expect in the JIT. 
- using ComputeFunctionType = void (*)( - void* /*result*/, const ExecutableRunOptions* /*run_options*/, - const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/); + using ComputeFunctionType = + void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/, + const void** /*args*/, void** /*buffer_table*/, + int64* /*profile_counters*/); const ComputeFunctionType& compute_function() const { return compute_function_; @@ -92,18 +93,18 @@ class CpuExecutable : public Executable { // exists) must out-live the task. StatusOr ExecuteAsyncOnStreamImpl( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile); - // Creates an array suitable for passing as the "temps" argument to the JIT - // compiled function pointer. + // Creates an array suitable for passing as the "buffer_table" argument to the + // JIT compiled function pointer. // // Returns (unowning_buffers, owning_buffers) where: // - // - unowning_buffers.data() can be passed as the temps argument as-is and - // includes pointers to the scratch storage required by the computation, - // the live-out buffer into which the result will be written and entry - // computation parameters. + // - unowning_buffers.data() can be passed as the buffer_table argument as-is + // and includes pointers to the scratch storage required by the + // computation, the live-out buffer into which the result will be written + // and entry computation parameters. // // - owning_buffers contains owning pointers to the buffers that were // allocated by this routine. This routine allocates buffers for temporary @@ -111,22 +112,21 @@ class CpuExecutable : public Executable { // result. 
StatusOr, std::vector>> - CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal, - tensorflow::gtl::ArraySlice arguments); + CreateBufferTable(DeviceMemoryAllocator* memory_allocator, int device_ordinal, + absl::Span arguments); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. - Status ExecuteComputeFunction( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile); + Status ExecuteComputeFunction(const ExecutableRunOptions* run_options, + absl::Span buffers, + HloExecutionProfile* hlo_execution_profile); // Creates a ScopedShapedBuffer for holding the result of the computation, // moving buffers out of allocated_buffers and into the result as appropriate. // The addresses are set according to buffer assignment. StatusOr CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::MutableArraySlice buffers); + absl::Span buffers); // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc index 7bd4741a04b1135d9780e0cf765b7b33378526e1..7fbe0fa157c57eb0c274662a1de95cf5328ccfa8 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc @@ -34,9 +34,8 @@ StatusOr CpuHloSupportChecker::Run(HloModule* module) { return xla::Unimplemented( "CPU backend does not support HLO instruction %s with shape " "containing a sparse layout: %s", - instruction->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction->shape()) - .c_str()); + instruction->ToString(), + ShapeUtil::HumanStringWithLayout(instruction->shape())); } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h index 2924b6365943f0a3ec998d7a77767a76cbb576ae..6af724b2a5d71b9c30f3485ffb7e51d1d201cb6b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -28,9 +28,7 @@ class CpuHloSupportChecker : public HloPassInterface { CpuHloSupportChecker() = default; ~CpuHloSupportChecker() override = default; - tensorflow::StringPiece name() const override { - return "cpu_hlo_support_checker"; - } + absl::string_view name() const override { return "cpu_hlo_support_checker"; } // Note: always returns false (no instructions are ever modified by this // pass). 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index b40d264c03aba6e9308e8a621ae86e180e33c335..f9cd61bea3dc86cadff99d4a90eca44c16520823 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -35,7 +35,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kGather || - hlo.opcode() == HloOpcode::kPad || + hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kReverse || hlo.opcode() == HloOpcode::kSlice || @@ -78,7 +78,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } if (!CanBeLoopFused(*producer)) { - VLOG(2) << "Producer is not fusile."; + VLOG(2) << "Producer is not fusible."; return false; } @@ -140,7 +140,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } if (CanBeLoopFused(*consumer)) { - VLOG(2) << "Fusing: consumer is elementwise or fusile."; + VLOG(2) << "Fusing: consumer is elementwise or fusible."; return true; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 991b14f17dbc8cd061af98e032824d3f7075e78b..284929ca073ca0d8c5c7cc383f8341a53d0f9e88 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -18,11 +18,12 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -566,7 +567,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { HloOpcode::kParameter, HloOpcode::kParameter}); } -TEST_F(OpcodeFusionTest, MessOfFusileNodes) { +TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); @@ -697,8 +698,9 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); if (add_extra_use_for_dot) { + auto* token = builder.AddInstruction(HloInstruction::CreateToken()); builder.AddInstruction( - HloInstruction::CreateOutfeed(dot_shape, dot, "no_config")); + HloInstruction::CreateOutfeed(dot_shape, dot, token, "no_config")); } module->AddEntryComputation(builder.Build()); @@ -772,8 +774,8 @@ class GatherLoopFusionTest TEST_P(GatherLoopFusionTest, GatherLoopFusion) { const GatherLoopFusionTestSpec& spec = GetParam(); - string hlo_string = tensorflow::strings::StrCat( - "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); + string hlo_string = absl::StrCat("HloModule ", spec.test_name, "\n\n", + spec.hlo_computation_text); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseHloString(hlo_string)); @@ -791,11 +793,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} 
one = s32[] constant(1) one_broadcasted = s32[3,2] broadcast(one), dimensions={} ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) @@ -807,11 +809,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) @@ -823,11 +825,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -839,11 +841,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -855,11 +857,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, 
index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -871,11 +873,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[1,1] broadcast(one), dimensions={} ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) @@ -887,11 +889,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index aa872d5ec9e7593b8d2f731421c17af590729529..bfecbd6e017893e4f6d3dcbc01d46c899e6060fa 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -34,8 +34,8 @@ namespace cpu { // instruction stream. 
namespace { -using ::tensorflow::gtl::nullopt; -using ::tensorflow::gtl::optional; +using absl::nullopt; +using absl::optional; using ShouldMakeOperandColMajorCache = tensorflow::gtl::FlatMap; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 3681d12d8da818d06d2f690024008c9ccb896286..9363af3b8941c68284915d6770188bde4c87f78e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index 3ed7876715f64191f6e652d2b5cb1673df9a1b94..b8ace5702688096822573c7afae234cbcbe77b28 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -15,8 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace { @@ -45,17 +46,16 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) { return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0; } -tensorflow::gtl::optional LlvmIrGemvTilingFactor( - const HloModuleConfig& config) { +absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config) { const auto& extra_options_map = config.debug_options().xla_backend_extra_options(); auto it = extra_options_map.find(kLlvmIrDotTilingFactor); int64 tiling_factor; if (it != extra_options_map.end() && - tensorflow::strings::safe_strto64(it->second, &tiling_factor)) { + absl::SimpleAtoi(it->second, &tiling_factor)) { return tiling_factor; } - return tensorflow::gtl::nullopt; + return absl::nullopt; } bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { @@ -64,38 +64,37 @@ bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; } -static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str, - tensorflow::StringPiece suffix) { +static absl::string_view RemoveSuffix(absl::string_view str, + absl::string_view suffix) { CHECK_GE(str.size(), suffix.size()); CHECK_EQ(str.substr(str.size() - suffix.size()), suffix); return str.substr(0, str.size() - suffix.size()); } -tensorflow::gtl::optional> LlvmIrGemmTileSize( +absl::optional> LlvmIrGemmTileSize( const HloModuleConfig& config) { const auto& extra_options_map = config.debug_options().xla_backend_extra_options(); auto it = extra_options_map.find(kLlvmIrGemmTileSize); if (it == extra_options_map.end()) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } - std::vector tile_components = - tensorflow::str_util::Split(it->second, ':'); + std::vector tile_components = 
absl::StrSplit(it->second, ':'); CHECK_EQ(tile_components.size(), 3); int64 tile_size_m; int64 tile_size_k; int64 tile_size_n_in_vector_width; - CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m)); - CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k)); + CHECK(absl::SimpleAtoi(tile_components[0], &tile_size_m)); + CHECK(absl::SimpleAtoi(tile_components[1], &tile_size_k)); - tensorflow::StringPiece tile_size_n_in_vector_width_str = + absl::string_view tile_size_n_in_vector_width_str = RemoveSuffix(tile_components[2], "*vectwidth"); - CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str, - &tile_size_n_in_vector_width)); + CHECK(absl::SimpleAtoi(tile_size_n_in_vector_width_str, + &tile_size_n_in_vector_width)); return std::tuple(tile_size_m, tile_size_k, tile_size_n_in_vector_width); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 429b9e16cbdd6f623919533582481f1640118081..47c7eb13b6e4cc05a23f82b8d2a25249f4b82ac0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -27,9 +27,8 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); -tensorflow::gtl::optional LlvmIrGemvTilingFactor( - const HloModuleConfig& config); -tensorflow::gtl::optional> LlvmIrGemmTileSize( +absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config); +absl::optional> LlvmIrGemmTileSize( const HloModuleConfig& config); } // namespace options diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 639064040f521a9e84bd87c5d05f674204e4d6e2..8a44c384bb0fe6f132c352ca8bd78baa23d093d4 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ 
b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index 2ac950e6d93ade315808f2ca1d0bdd7bc85f53b9..1ae3aa57111e3a3b7ac18b4907c5c282edf89b7e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -19,16 +19,16 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -46,7 +46,7 @@ std::unique_ptr> MaybeTransposeArray2D(const Array2D& array, if (transpose) { std::swap(output_width, output_height); } - auto output = MakeUnique>(output_height, output_width); + auto output = absl::make_unique>(output_height, output_width); for (int y = 0; y < array.height(); y++) { for (int x = 0; x < array.width(); x++) { if (transpose) { @@ -93,7 +93,7 @@ std::unique_ptr> EigenMatrixMultiply(const 
Array2D& a, // Since we're going to transpose c before returning it. Swap the order of the // dimension sizes to ensure the returned array is properly dimensioned. - auto c_transpose = MakeUnique>(n, m); + auto c_transpose = absl::make_unique>(n, m); if (single_threaded) { __xla_cpu_runtime_EigenSingleThreadedMatMulF32( nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), @@ -142,10 +142,10 @@ class EigenMatMulTest : public CpuRuntimeTest, bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); - return tensorflow::strings::Printf( - "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, - transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", - single_threaded ? "single" : "multi"); + return absl::StrFormat("EigenMatMul_%d_%d_%d_%s%s%s_threaded", shape.m, + shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", + transpose_rhs ? "Trhs_" : "", + single_threaded ? "single" : "multi"); } }; @@ -178,10 +178,10 @@ class MKLMatMulTest : public CpuRuntimeTest, bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); - return tensorflow::strings::Printf( - "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, - transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", - single_threaded ? "single" : "multi"); + return absl::StrFormat("MKLMatMul_%d_%d_%d_%s%s%s_threaded", shape.m, + shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", + transpose_rhs ? "Trhs_" : "", + single_threaded ? "single" : "multi"); } }; @@ -204,7 +204,7 @@ std::unique_ptr> MKLMatrixMultiply(const Array2D& a, // Since we're going to transpose c before returning it, swap the order of the // dimension sizes to ensure the returned array is properly dimensioned. 
- auto c_transpose = MakeUnique>(n, m); + auto c_transpose = absl::make_unique>(n, m); if (single_threaded) { __xla_cpu_runtime_MKLSingleThreadedMatMulF32( nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 59bc7e0e16fcc66a010408259a1ccfb2b6bb35fd..5519a43b2f6bc3a7df9a58823e43fae42f7f94df 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" @@ -103,7 +104,7 @@ Status CpuTransferManager::TransferLiteralToInfeed( if (ShapeUtil::IsNestedTuple(shape)) { return Unimplemented( "Infeed with a nested tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); + ShapeUtil::HumanString(literal.shape())); } // For a tuple, we transfer each of its elements to the device and @@ -151,11 +152,11 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); + return InvalidArgument("Infeed shape is too large: needs %d bytes", size); } if (size <= 0) { - return InvalidArgument("Infeed shape must have positive size; got %lld", + return InvalidArgument("Infeed shape must have positive size; got %d", size); } @@ -178,7 +179,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( int64 size = GetByteSizeRequirement(literal_shape); // Note: OSS build didn't like implicit conversion from // literal_shape.dimensions() to the array slice on 2017-07-10. 
- tensorflow::gtl::ArraySlice dimensions( + absl::Span dimensions( tensorflow::bit_cast(literal_shape.dimensions().data()), literal_shape.dimensions().size()); TF_ASSIGN_OR_RETURN( @@ -224,7 +225,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( StatusOr CpuTransferManager::TransferTupleBuffersFromOutfeed( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data) { + absl::Span> buffer_data) { return TransferBuffersFromOutfeedInternal(executor, buffer_data, /*is_tuple=*/true); } @@ -237,18 +238,17 @@ StatusOr CpuTransferManager::TransferArrayBufferFromOutfeed( StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data, - bool is_tuple) { + absl::Span> buffer_data, bool is_tuple) { std::vector> buffers; for (auto b : buffer_data) { int64 size = b.second; if (size > std::numeric_limits::max()) { - return InvalidArgument("Outfeed shape is too large: needs %lld bytes", + return InvalidArgument("Outfeed shape is too large: needs %d bytes", size); } if (size <= 0) { - return InvalidArgument("Outfeed shape must have positive size; got %lld", + return InvalidArgument("Outfeed shape must have positive size; got %d", size); } @@ -256,7 +256,7 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( VLOG(2) << "Enqueueing outfeed buffer (for the device to populate) of length " << size_32 << "B"; - buffers.emplace_back(MakeUnique(b.first, size_32)); + buffers.emplace_back(absl::make_unique(b.first, size_32)); } std::vector buffer_pointers; @@ -283,7 +283,7 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( } // namespace xla static std::unique_ptr CreateCpuTransferManager() { - return xla::MakeUnique(); + return absl::make_unique(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 
80ef953d532798281c10b7a212b9c4d84a790c27..361d4b9c8422fff6afe53e56e0bb10a484c9becc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -13,18 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_ #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -56,7 +56,7 @@ class CpuTransferManager : public GenericTransferManager { // Helper that transfers a tuple of element buffers from the device's outfeed. StatusOr TransferTupleBuffersFromOutfeed( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data); + absl::Span> buffer_data); // Helper that transfers an array buffer from the device's outfeed. StatusOr TransferArrayBufferFromOutfeed(se::StreamExecutor* executor, @@ -68,12 +68,11 @@ class CpuTransferManager : public GenericTransferManager { // for the given buffers. 
StatusOr TransferBuffersFromOutfeedInternal( se::StreamExecutor* executor, - tensorflow::gtl::ArraySlice> buffer_data, - bool is_tuple); + absl::Span> buffer_data, bool is_tuple); TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager); }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc index e4c674e227ffc6725ca929f720b9aa7cf7c4c032..3ae64142cd7e32d3aa8d50870efaf94698c06440 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.cc +++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc @@ -21,13 +21,13 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "llvm/MC/MCInst.h" #include "llvm/Support/TargetRegistry.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -151,7 +151,7 @@ StatusOr Disassembler::DisassembleObjectFile( size = 1; } - ostream << tensorflow::strings::Printf("0x%08lx", index) << " "; + ostream << absl::StrFormat("0x%08lx", index) << " "; if (decode_status == llvm::MCDisassembler::Success) { // For branches, try to determine the actual address and emit it as an @@ -163,7 +163,7 @@ StatusOr Disassembler::DisassembleObjectFile( uint64_t target; if (inst_analysis_->evaluateBranch( instruction, section_address + index, size, target)) { - annotation = tensorflow::strings::Printf("[0x%08lx]", target); + annotation = absl::StrFormat("[0x%08lx]", target); } } inst_printer_->printInst(&instruction, ostream, annotation.c_str(), diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc 
b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index f2ac742b6e6fc12076e7a2a242155c005f4b05b8..99fa707c959854e50c6d954fe92b87e93e267dc6 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -79,7 +80,7 @@ class MemoryTile { // `minor_dim_offset`}. // // Note: `major_dim_offset` is a parameter to the constructor. - void StoreTile(tensorflow::gtl::ArraySlice tile, + void StoreTile(absl::Span tile, llvm::Value* minor_dim_offset) const { CHECK_EQ(tile.size(), pointers_.size()); for (int64 i = 0; i < pointers_.size(); i++) { @@ -146,9 +147,9 @@ class GemvConfig { bool has_addend() const { return has_addend_; } string GetCacheKey() const { - return tensorflow::strings::StrCat( - name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_", - tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : ""); + return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", + tile_rows(), "_", tile_cols(), "_", m(), "_", k(), + has_addend() ? "_with_addend" : ""); } protected: @@ -621,19 +622,19 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( } // This class implements a tiled matrix multiplication algorithm, intended for -// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in "Goto, -// Kazushige, and Robert Van De Geijn. "High-performance implementation of the -// level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 (2008): -// 4). +// multiplying small matrices that don't need cache tiling. +// +// In the future this can be used as the innermost GEBP loop in a GEMM kernel as +// described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of +// high-performance matrix multiplication." 
ACM Transactions on Mathematical +// Software (TOMS) 34.3 (2008): 12.". // // This only supports canonical dot operations (i.e. where the lhs contraction // dimension is 1 and the rhs contraction dimension is 0) over row major // matrices. -class MatrixMatrixBlockPanelEmitter { +class TiledSmallGemmEmitter { public: - // Describe the dimensions of the GEBP kernel. These will usually not be the - // dimensions of the GEMM itself, the GEMM will usually be broken up into GEBP - // kernels with smaller dimensions. + // Describe the dimensions of the kernel. class Dimensions { public: explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} @@ -642,9 +643,7 @@ class MatrixMatrixBlockPanelEmitter { int64 k() const { return k_; } int64 n() const { return n_; } - string ToString() const { - return tensorflow::strings::StrCat(m(), "x", k(), "x", n()); - } + string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } private: const int64 m_; @@ -652,9 +651,9 @@ class MatrixMatrixBlockPanelEmitter { const int64 n_; }; - // Represents the configuration of the GEBP emitter. The LLVM IR emitted by - // the emitter, modulo the LLVM values holding the input and output buffers, - // must be a function of the instance of `Config` passed to it. + // Represents the configuration of the emitter. The LLVM IR emitted by the + // emitter, modulo the LLVM values holding the input and output buffers, must + // be a function of the instance of `Config` passed to it. // // `dims` holds the matrix multiplication dimensions. 
// @@ -687,10 +686,10 @@ class MatrixMatrixBlockPanelEmitter { tile_size_k_(tile_size_k) {} string GetCacheKey() const { - return tensorflow::strings::StrCat( - "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(), - "_", max_vectorization_width(), "_", min_vectorization_width(), "_", - tile_size_m(), "_", tile_size_k()); + return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_", + dims().ToString(), "_", max_vectorization_width(), + "_", min_vectorization_width(), "_", tile_size_m(), + "_", tile_size_k()); } PrimitiveType scalar_type() const { return scalar_type_; } @@ -712,11 +711,11 @@ class MatrixMatrixBlockPanelEmitter { int64 tile_size_k_; }; - // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies + // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies // `lhs` with `rhs` and stores the result in `result`. - explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* result, - llvm::IRBuilder<>* b) + explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* b) : lhs_(lhs), rhs_(rhs), result_(result), @@ -780,9 +779,9 @@ class MatrixMatrixBlockPanelEmitter { KernelSupportLibrary ksl_; }; -void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); } +void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); } -void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { +void TiledSmallGemmEmitter::HandleResiduesOnN() { // We can only iterate the `n` dimension for an extent that is divisible by // the vectorization width. 
So we emit an outer loop that first processes the // largest extent in `n` that is divisible by max_vectorization_width, then @@ -799,7 +798,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { int64 n_end = dims().n() - (dims().n() % current_vectorization_width); if (n_start != n_end) { VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, - "gebp"); + "gemm"); HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); n_start = n_end; } @@ -813,7 +812,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { } if (n_start != dims().n()) { - VectorSupportLibrary vsl(scalar_type(), 1, b_, "gebp"); + VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); HandleResiduesOnK(&vsl, n_i, n_i_next); @@ -821,9 +820,9 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { } } -void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, - llvm::Value* n_start, - llvm::Value* n_end) { +void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { int64 k_start = 0; int64 k_end = dims().k() - (dims().k() % tile_size_k()); if (k_end != k_start) { @@ -838,7 +837,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, } } -void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( +void TiledSmallGemmEmitter::HandleResiduesOnM( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { const int64 m_end = dims().m() - dims().m() % tile_size_m(); @@ -921,7 +920,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( // +-------------------+-------------------+-------------------+--------- // | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... 
// +-------------------+-------------------+-------------------+--------- -void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( +void TiledSmallGemmEmitter::EmitTiledGemm( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { @@ -1001,12 +1000,22 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, return dot_emitter.Emit(); } -bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( +bool DotOpEmitter::EmitSmallGemmIfProfitable( const DotOpEmitter::MatMultDims& mat_mult_dims) { - if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) { + if (ShouldUseMultiThreadedEigen()) { return false; } + if (!EnableExperimentalLlvmIrGemm()) { + // TODO(sanjoy): We should make these numbers micro-arch specific. + bool small_gemm = mat_mult_dims.k <= 128 && + ((mat_mult_dims.m <= 32 && mat_mult_dims.n <= 128) || + (mat_mult_dims.m <= 128 && mat_mult_dims.n <= 32)); + if (!small_gemm) { + return false; + } + } + if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) { return false; } @@ -1054,15 +1063,15 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = GetGemmTileSize(); - MatrixMatrixBlockPanelEmitter::Config config( + TiledSmallGemmEmitter::Config config( /*scalar_type=*/primitive_type, - MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, + TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, /*max_vectorization_width=*/max_target_vector_width, /*max_vector_count=*/tile_size_n_in_vector_width, /*min_vectorization_width=*/std::min(4, max_target_vector_width), /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); - VLOG(2) << "Emitting GEBP kernel in LLVM IR with config " + VLOG(2) << "Emitting GEMM kernel in LLVM IR with config " << config.GetCacheKey(); const bool enable_fast_math = @@ -1075,10 +1084,10 @@ 
bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs, rhs, target, [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { - MatrixMatrixBlockPanelEmitter gebp_emitter(config, /*lhs=*/lhs, - /*rhs=*/rhs, - /*result=*/target, b_); - gebp_emitter.Emit(); + TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, + /*rhs=*/rhs, + /*result=*/target, b_); + small_gemm_emitter.Emit(); }); return true; @@ -1136,7 +1145,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { } if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return EmitExperimentalGebpDotIfEnabled(mat_mult_dims); + return EmitSmallGemmIfProfitable(mat_mult_dims); } int64 tiling_factor = GetGemvTilingFactor(); @@ -1458,7 +1467,7 @@ Status DotOpEmitter::EmitCallToRuntime() { break; default: return Unimplemented("Invalid type %s for dot operation", - PrimitiveType_Name(type).c_str()); + PrimitiveType_Name(type)); } llvm::Type* float_ptr_type = float_type->getPointerTo(); @@ -1610,7 +1619,7 @@ bool PotentiallyImplementedAsEigenDot( // For vector-matrix dot products, it is always profitable to make the Rhs // column major. -tensorflow::gtl::optional ProfitableToMakeDotOperandColumnMajor( +absl::optional ProfitableToMakeDotOperandColumnMajor( const HloInstruction& hlo) { if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 && hlo.shape().dimensions(0) == 1) { diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 590032fbe907d7ca90bf69b7ccc3170b8efec72e..4c2041b556aa8bf8fe8fb8e0674c0f4f04f0acae 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ +#include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -38,7 +38,7 @@ bool PotentiallyImplementedAsEigenDot( // Returns the index for an operand to `hlo` that should ideally be column // major. Returns nullopt if there is no such operand or if `hlo` is not a dot // or a fusion containing a dot. -tensorflow::gtl::optional ProfitableToMakeDotOperandColumnMajor( +absl::optional ProfitableToMakeDotOperandColumnMajor( const HloInstruction& hlo); // Returns true to indicate that we can generate a tiled LLVM IR implementation @@ -121,7 +121,7 @@ class DotOpEmitter { // of rank 2 as well). MatMultDims GetMatMultDims() const; - bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims); + bool EmitSmallGemmIfProfitable(const MatMultDims& mat_mult_dims); // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector // registers. diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index db54454707983ade31594119b2e868fa168d4cc2..c8312d80bd5012e5bcb42a410db18a7fa77a2eb6 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -30,15 +30,16 @@ limitations under the License. 
namespace xla { namespace cpu { -StatusOr CpuElementalIrEmitter::EmitAtan2( - PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { +StatusOr CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) { string function_name; bool cast_result_to_fp16 = false; switch (prim_type) { case F16: cast_result_to_fp16 = true; - lhs = b_->CreateFPCast(lhs, b_->getFloatTy()); - rhs = b_->CreateFPCast(rhs, b_->getFloatTy()); + lhs = FPCast(lhs, b_->getFloatTy()); + rhs = FPCast(rhs, b_->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "atan2f"; @@ -58,21 +59,21 @@ StatusOr CpuElementalIrEmitter::EmitAtan2( function->setDoesNotThrow(); function->setDoesNotAccessMemory(); // Create an instruction to call the function. - llvm::Value* result = b_->CreateCall(function, {lhs, rhs}); + llvm::Value* result = Call(function, {lhs, rhs}); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } -StatusOr CpuElementalIrEmitter::EmitTanh( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) { bool cast_result_to_fp16 = false; string function_name; switch (prim_type) { case F16: cast_result_to_fp16 = true; - value = b_->CreateFPCast(value, b_->getFloatTy()); + value = FPCast(value, b_->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "tanhf"; @@ -91,16 +92,16 @@ StatusOr CpuElementalIrEmitter::EmitTanh( function->setDoesNotThrow(); function->setDoesNotAccessMemory(); // Create an instruction to call the function. 
- llvm::Value* result = b_->CreateCall(function, value); + llvm::Value* result = Call(function, value); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const { + const HloToElementGeneratorMap& operand_to_generator) { if (hlo->opcode() == HloOpcode::kMap) { return [this, hlo, &operand_to_generator]( const llvm_ir::IrArray::Index& index) -> StatusOr { diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 76833e765d05f2477961cd06cead66797c5be623..e3fba9306b72904803259047fafea245a8e183db 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -36,13 +36,13 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const override; + const HloToElementGeneratorMap& operand_to_generator) override; protected: StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 6f433b4f30372da9cf4503396dbb60172cfc0cb0..e5cf15c686157d837901fa912bdde2a7a5d501d9 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -27,6 +27,9 @@ limitations under the License. 
#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/BasicBlock.h" @@ -64,11 +67,8 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -100,6 +100,11 @@ IrEmitter::IrEmitter( b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_cpu_enable_fast_math())); + Status s = GatherComputationsByAllocationType( + &hlo_module, &thread_local_computations_, &global_computations_); + absl::c_sort(thread_local_computations_); + absl::c_sort(global_computations_); + TF_CHECK_OK(s) << "Should have failed buffer assignment."; } StatusOr IrEmitter::EmitComputation( @@ -170,9 +175,9 @@ IrEmitter::~IrEmitter() {} Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { VLOG(2) << "HandleBitcast: " << bitcast->ToString(); emitted_value_[bitcast] = - b_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)), - IrShapeType(bitcast->shape())->getPointerTo(), - AsStringRef(IrName(bitcast))); + BitCast(GetEmittedValueFor(bitcast->operand(0)), + IrShapeType(bitcast->shape())->getPointerTo(), + AsStringRef(IrName(bitcast))); return Status::OK(); } @@ -230,9 +235,8 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { // Use the elemental emitter for array shapes. 
return DefaultAction(copy); } - return Unimplemented( - "unsupported operand type %s for copy instruction", - PrimitiveType_Name(copy->shape().element_type()).c_str()); + return Unimplemented("unsupported operand type %s for copy instruction", + PrimitiveType_Name(copy->shape().element_type())); } // Calculate the alignment of a buffer allocated for a given primitive type. @@ -338,10 +342,10 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { // Write the tuple index table. TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, assignment_.GetUniqueSlice(infeed, {0})); - llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape); + llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice, assignment_.GetUniqueSlice(infeed, {1})); - llvm::Value* token_address = EmitTempBufferPointer( + llvm::Value* token_address = EmitBufferPointer( token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1)); llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address}, &b_, module_); @@ -364,9 +368,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* instruction) { // Only the outer tuple buffer's target address is obtained from // GetEmittedValueFor, to handle the case when Infeed is the root // instruction. Target addresses for internal elements can be obtained - // from EmitTempBufferPointer. + // from EmitBufferPointer. 
llvm::Value* tuple_element_address = - EmitTempBufferPointer(buffer, tuple_element_shape); + EmitBufferPointer(buffer, tuple_element_shape); TF_RETURN_IF_ERROR(EmitXfeedTransfer( XfeedKind::kInfeed, tuple_element_shape, tuple_element_address)); @@ -389,7 +393,7 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, int64 length = ByteSizeOf(shape); if (length <= 0 || length > std::numeric_limits::max()) { return InvalidArgument( - "xfeed (infeed or outfeed) buffer length %lld is outside the valid " + "xfeed (infeed or outfeed) buffer length %d is outside the valid " "size range", length); } @@ -440,27 +444,33 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, // of size exactly 'length_32', and the runtime is responsible for // check-failing the process if there is a mismatch, versus passing us back a // buffer that we might overrun. - llvm::Value* acquired_pointer = b_.CreateCall( - acquire_func, - {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)}); + llvm::Value* acquired_pointer = + Call(acquire_func, + {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)}); if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. - b_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer, - /*SrcAlign=*/1, length_32); + MemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer, + /*SrcAlign=*/1, length_32); } else { // Outfeed -- copy from the in-program address to the acquired buffer. 
- b_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address, - /*SrcAlign=*/1, length_32); + MemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address, + /*SrcAlign=*/1, length_32); } - b_.CreateCall(release_func, {b_.getInt32(length_32), acquired_pointer, - shape_ptr, b_.getInt32(shape_length)}); + Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr, + b_.getInt32(shape_length)}); return Status::OK(); } Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { + // Outfeed produces no useful result, but it does return a token[] that can be + // threaded through to other side effecting operations to ensure ordering. In + // the IR emitter we treat this token as a normal u8[] and thus need to insert + // an entry for it in emitted_value_. + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(outfeed)); + HloInstruction* operand = outfeed->operands()[0]; const Shape& operand_shape = operand->shape(); @@ -501,8 +511,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { llvm::Value* IrEmitter::EmitElementalMap( const HloMapInstruction& map_instr, - tensorflow::gtl::ArraySlice elemental_operands, - tensorflow::StringPiece name) { + absl::Span elemental_operands, absl::string_view name) { return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); } @@ -519,8 +528,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), "reduce_window_accumulator_address", &b_, MinimumAlignmentForPrimitiveType(operand_element_type)); - b_.CreateStore(b_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))), - accumulator_address); + Store(Load(GetEmittedValueFor(reduce_window->operand(1))), + accumulator_address); llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_); std::vector window_size; @@ -537,22 +546,21 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm::Value* in_bounds_condition = nullptr; for (size_t i = 0; i < 
index.size(); ++i) { llvm::Value* strided_index = - b_.CreateNSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_index[i] = - b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); + NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); + input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); // We need to check if 0 <= input_index[i] < bound, as otherwise we are in // the padding so that we can skip the computation. That is equivalent to // input_index[i] < bound as an *unsigned* comparison, since a negative // value will wrap to a large positive value. - llvm::Value* index_condition = b_.CreateICmpULT( - input_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + llvm::Value* index_condition = + ICmpULT(input_index[i], + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); if (in_bounds_condition == nullptr) { in_bounds_condition = index_condition; } else { - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + in_bounds_condition = And(in_bounds_condition, index_condition); } } CHECK(in_bounds_condition != nullptr); @@ -565,12 +573,12 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm_ir::IrArray input_array(GetIrArrayFor(operand)); llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); llvm::Value* result = EmitThreadLocalCall( - *reduce_window->to_apply(), - {b_.CreateLoad(accumulator_address), input_value}, "reducer_function"); - b_.CreateStore(result, accumulator_address); + *reduce_window->to_apply(), {Load(accumulator_address), input_value}, + "reducer_function"); + Store(result, accumulator_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_address); + return Load(accumulator_address); } Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { @@ 
-647,7 +655,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"), [this, init_value](const llvm_ir::IrArray::Index& target_index) { llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - return b_.CreateLoad(init_value_addr); + return Load(init_value_addr); })); // Create a loop to iterate over the source array to scatter to the output. @@ -667,7 +675,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( b_.getInt1Ty(), "initialized_flag_address", &b_); - b_.CreateStore(b_.getInt1(false), initialized_flag_address); + Store(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_); @@ -685,15 +693,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size()); llvm::Value* in_bounds_condition = b_.getTrue(); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = b_.CreateNSWMul( - source_index[i], b_.getInt64(window.dimensions(i).stride())); - operand_index[i] = - b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); - llvm::Value* index_condition = b_.CreateICmpULT( - operand_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + llvm::Value* strided_index = + NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride())); + operand_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); + llvm::Value* index_condition = + ICmpULT(operand_index[i], + 
b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + in_bounds_condition = And(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); @@ -703,7 +710,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - b_.CreateLoad(initialized_flag_address), "initialized", &b_); + Load(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. @@ -712,38 +719,37 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { [&](const llvm_ir::IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - b_.CreateStore(operand_index[i], selected_index_address_slot); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + Store(operand_index[i], selected_index_address_slot); } }; llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); - b_.CreateStore(operand_data, selected_value_address); + Store(operand_data, selected_value_address); save_operand_index(operand_index); - b_.CreateStore(b_.getInt1(true), initialized_flag_address); + Store(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to potentially // update the selected value and index with the currently visiting operand. 
SetToFirstInsertPoint(if_initialized.true_block, &b_); llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); - llvm::Value* operand_element = b_.CreateLoad(operand_address); + llvm::Value* operand_element = Load(operand_address); llvm::Value* result = EmitThreadLocalCall( *select_and_scatter->select(), - {b_.CreateLoad(selected_value_address), operand_element}, - "select_function"); + {Load(selected_value_address), operand_element}, "select_function"); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. - llvm::Value* cond = b_.CreateICmpNE( + llvm::Value* cond = ICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); SetToFirstInsertPoint(if_select_lhs.false_block, &b_); - b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); + Store(Load(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -754,8 +760,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::IrArray::Index selected_index(source_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(Load(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); llvm::Value* source_value = @@ -837,7 +843,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( lhs_llvm_type, "convolution_sum_address", &b_, MinimumAlignmentForPrimitiveType(lhs_element_type)); llvm::Value* constant_zero = 
llvm::Constant::getNullValue(lhs_llvm_type); - b_.CreateStore(constant_zero, sum_address); + Store(constant_zero, sum_address); llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_); std::vector kernel_spatial(num_spatial_dims); @@ -846,7 +852,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( loops .AddLoop( 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)), - tensorflow::strings::StrCat("k", i)) + absl::StrCat("k", i)) ->GetIndVarValue(); } llvm::Value* input_feature = @@ -864,11 +870,11 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm::Value* kernel_index, const WindowDimension& window_dim) { llvm::Value* strided_index = - b_.CreateNSWMul(output_index, b_.getInt64(window_dim.stride())); - llvm::Value* dilated_kernel_index = b_.CreateNSWMul( - kernel_index, b_.getInt64(window_dim.window_dilation())); - return b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, dilated_kernel_index), - b_.getInt64(window_dim.padding_low())); + NSWMul(output_index, b_.getInt64(window_dim.stride())); + llvm::Value* dilated_kernel_index = + NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation())); + return NSWSub(NSWAdd(strided_index, dilated_kernel_index), + b_.getInt64(window_dim.padding_low())); }; std::vector input_spatial(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -885,9 +891,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( // Also need to check that the input coordinates are not in one of the // holes created by base dilation. 
const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) { - llvm::Value* remainder = - b_.CreateSRem(input_index, b_.getInt64(base_dilation)); - return b_.CreateICmpEQ(remainder, b_.getInt64(0)); + llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation)); + return ICmpEQ(remainder, b_.getInt64(0)); }; llvm::Value* in_bounds_condition = b_.getInt1(true); @@ -895,17 +900,17 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound( lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), window.dimensions(i).base_dilation())); - llvm::Value* dim_in_bound = b_.CreateICmpULT(input_spatial[i], input_bound); + llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound); llvm::Value* dim_not_in_hole = not_in_hole(input_spatial[i], window.dimensions(i).base_dilation()); - llvm::Value* dim_ok = b_.CreateAnd(dim_in_bound, dim_not_in_hole); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, dim_ok); + llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole); + in_bounds_condition = And(in_bounds_condition, dim_ok); } // Now we need to map the dilated base coordinates back to the actual // data indices on the lhs. const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) { - return b_.CreateSDiv(input_index, b_.getInt64(base_dilation)); + return SDiv(input_index, b_.getInt64(base_dilation)); }; for (int i = 0; i < num_spatial_dims; ++i) { input_spatial[i] = @@ -930,8 +935,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( for (int i = 0; i < num_spatial_dims; ++i) { kernel_index[dnums.kernel_spatial_dimensions(i)] = window.dimensions(i).window_reversal() - ? b_.CreateNSWSub(b_.getInt64(window.dimensions(i).size() - 1), - kernel_spatial[i]) + ? 
NSWSub(b_.getInt64(window.dimensions(i).size() - 1), + kernel_spatial[i]) : kernel_spatial[i]; } @@ -940,13 +945,13 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); llvm::Value* product = - b_.CreateFMul(input_array.EmitReadArrayElement(input_index, &b_), - kernel_array.EmitReadArrayElement(kernel_index, &b_)); - llvm::Value* sum = b_.CreateFAdd(b_.CreateLoad(sum_address), product); - b_.CreateStore(sum, sum_address); + FMul(input_array.EmitReadArrayElement(input_index, &b_), + kernel_array.EmitReadArrayElement(kernel_index, &b_)); + llvm::Value* sum = FAdd(Load(sum_address), product); + Store(sum, sum_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(sum_address); + return Load(sum_address); } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { @@ -1072,34 +1077,32 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { conv_func->setCallingConv(llvm::CallingConv::C); conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); - b_.CreateCall( - conv_func, - { - GetExecutableRunOptionsArgument(), - b_.CreateBitCast(GetEmittedValueFor(convolution), ir_ptr_type), - b_.CreateBitCast(lhs_address, ir_ptr_type), - b_.CreateBitCast(rhs_address, ir_ptr_type), - b_.getInt64(input_batch), - b_.getInt64(input_rows), - b_.getInt64(input_cols), - b_.getInt64(input_channels), - b_.getInt64(kernel_rows), - b_.getInt64(kernel_cols), - b_.getInt64(kernel_channels), - b_.getInt64(kernel_filters), - b_.getInt64(output_rows), - b_.getInt64(output_cols), - b_.getInt64(row_stride), - b_.getInt64(col_stride), - b_.getInt64(padding_top), - b_.getInt64(padding_bottom), - b_.getInt64(padding_left), - b_.getInt64(padding_right), - b_.getInt64(lhs_row_dilation), - b_.getInt64(lhs_col_dilation), - b_.getInt64(rhs_row_dilation), - b_.getInt64(rhs_col_dilation), - }); + Call(conv_func, { + GetExecutableRunOptionsArgument(), + 
BitCast(GetEmittedValueFor(convolution), ir_ptr_type), + BitCast(lhs_address, ir_ptr_type), + BitCast(rhs_address, ir_ptr_type), + b_.getInt64(input_batch), + b_.getInt64(input_rows), + b_.getInt64(input_cols), + b_.getInt64(input_channels), + b_.getInt64(kernel_rows), + b_.getInt64(kernel_cols), + b_.getInt64(kernel_channels), + b_.getInt64(kernel_filters), + b_.getInt64(output_rows), + b_.getInt64(output_cols), + b_.getInt64(row_stride), + b_.getInt64(col_stride), + b_.getInt64(padding_top), + b_.getInt64(padding_bottom), + b_.getInt64(padding_left), + b_.getInt64(padding_right), + b_.getInt64(lhs_row_dilation), + b_.getInt64(lhs_col_dilation), + b_.getInt64(rhs_row_dilation), + b_.getInt64(rhs_col_dilation), + }); return Status::OK(); } @@ -1159,15 +1162,14 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { fft_func->setDoesNotThrow(); fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); const int fft_rank = fft_length.size(); - b_.CreateCall( - fft_func, - {GetExecutableRunOptionsArgument(), - b_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), - b_.CreateBitCast(operand_address, int8_ptr_type), - b_.getInt32(fft->fft_type()), b_.getInt32(fft_rank), - b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), - b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), - b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); + Call(fft_func, + {GetExecutableRunOptionsArgument(), + BitCast(GetEmittedValueFor(fft), int8_ptr_type), + BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()), + b_.getInt32(fft_rank), b_.getInt64(input_batch), + b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), + b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), + b_.getInt64(fft_rank > 2 ? 
fft_length[2] : 0)}); return Status::OK(); } @@ -1203,11 +1205,11 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { const Shape& operand_shape = crs->operand(i)->shape(); CHECK(ShapeUtil::IsArray(operand_shape)) << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); - operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); + operand_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape)); // TODO(b/63762267): Be more aggressive about specifying alignment. - b_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, - /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); + MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, + /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); } llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_); return Status::OK(); @@ -1457,7 +1459,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( const ReductionGenerator& reduction_generator, const llvm_ir::IrArray::Index& output_index, const ShardedVectorType& accumulator_type, HloInstruction* init_value, - HloInstruction* arg, gtl::ArraySlice dimensions, + HloInstruction* arg, absl::Span dimensions, unsigned element_alignment) { ShardedVector accumulator; accumulator.reserve(accumulator_type.size()); @@ -1466,19 +1468,19 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( accumulator_shard_type, "accumulator", &b_, 0)); } - llvm::Value* init_value_ssa = b_.CreateLoad(GetEmittedValueFor(init_value)); + llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value)); for (llvm::Value* accumulator_shard : accumulator) { llvm::Value* initial_value; auto shard_type = accumulator_shard->getType()->getPointerElementType(); if (auto vector_type = llvm::dyn_cast(shard_type)) { initial_value = - b_.CreateVectorSplat(vector_type->getNumElements(), init_value_ssa); + VectorSplat(vector_type->getNumElements(), init_value_ssa); } else { initial_value = init_value_ssa; } - b_.CreateAlignedStore(initial_value, accumulator_shard, 
element_alignment); + AlignedStore(initial_value, accumulator_shard, element_alignment); } llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"), @@ -1500,24 +1502,24 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( } CHECK(output_index.end() == it); - llvm::Value* input_address = b_.CreateBitCast( + llvm::Value* input_address = BitCast( arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy()); for (int i = 0; i < accumulator.size(); i++) { auto input_address_typed = - b_.CreateBitCast(input_address, accumulator[i]->getType()); + BitCast(input_address, accumulator[i]->getType()); auto current_accumulator_value = - b_.CreateAlignedLoad(accumulator[i], element_alignment); - auto addend = b_.CreateAlignedLoad(input_address_typed, element_alignment); + AlignedLoad(accumulator[i], element_alignment); + auto addend = AlignedLoad(input_address_typed, element_alignment); arg_array.AnnotateLoadStoreInstructionWithMetadata(addend); auto reduced_result = reduction_generator(&b_, current_accumulator_value, addend); - b_.CreateAlignedStore(reduced_result, accumulator[i], element_alignment); + AlignedStore(reduced_result, accumulator[i], element_alignment); if (i != (accumulator.size() - 1)) { - input_address = b_.CreateConstInBoundsGEP1_32(reduced_result->getType(), - input_address_typed, 1); + input_address = ConstInBoundsGEP1_32(reduced_result->getType(), + input_address_typed, 1); } } @@ -1526,8 +1528,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( ShardedVector result_ssa; result_ssa.reserve(accumulator.size()); for (auto accumulator_shard : accumulator) { - result_ssa.push_back( - b_.CreateAlignedLoad(accumulator_shard, element_alignment)); + result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment)); } return result_ssa; } @@ -1536,25 +1537,25 @@ void IrEmitter::EmitShardedVectorStore( llvm::Value* store_address, const std::vector& value_to_store, const int alignment, const llvm_ir::IrArray& containing_array) { for 
(int i = 0; i < value_to_store.size(); i++) { - auto store_address_typed = b_.CreateBitCast( - store_address, - llvm::PointerType::getUnqual(value_to_store[i]->getType())); + auto store_address_typed = + BitCast(store_address, + llvm::PointerType::getUnqual(value_to_store[i]->getType())); - auto store_instruction = b_.CreateAlignedStore( - value_to_store[i], store_address_typed, alignment); + auto store_instruction = + AlignedStore(value_to_store[i], store_address_typed, alignment); containing_array.AnnotateLoadStoreInstructionWithMetadata( store_instruction); if (i != (value_to_store.size() - 1)) { - store_address = b_.CreateConstInBoundsGEP1_32( - value_to_store[i]->getType(), store_address_typed, 1); + store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(), + store_address_typed, 1); } } } StatusOr IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, - gtl::ArraySlice dimensions, HloComputation* function, + absl::Span dimensions, HloComputation* function, string* failure_reason) { if (!ReductionPreservesLayout(*reduce)) { return false; @@ -1620,9 +1621,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); int64 start_index = 0; int64 end_index = reduce->shape().dimensions(dimension); - std::unique_ptr loop = - loop_nest.AddLoop(start_index, end_index, - tensorflow::strings::Printf("dim.%lld", dimension)); + std::unique_ptr loop = loop_nest.AddLoop( + start_index, end_index, absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } @@ -1641,9 +1641,9 @@ StatusOr IrEmitter::EmitVectorizedReduce( int64 start_index = 0; int64 end_index = (innermost_dimension_size / vectorization_factor) * vectorization_factor; - std::unique_ptr loop = loop_nest.AddLoop( - start_index, end_index, vectorization_factor, - tensorflow::strings::Printf("dim.%lld", innermost_dimension)); + std::unique_ptr loop = + 
loop_nest.AddLoop(start_index, end_index, vectorization_factor, + absl::StrFormat("dim.%d", innermost_dimension)); array_index[innermost_dimension] = loop->GetIndVarValue(); SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_); @@ -1705,7 +1705,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) { const HloInstruction* arg = reduce->mutable_operand(0); const HloInstruction* init_value = reduce->mutable_operand(1); - gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); // Initialize an accumulator with init_value. PrimitiveType accumulator_type = reduce->shape().element_type(); @@ -1713,8 +1713,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator", &b_, MinimumAlignmentForPrimitiveType(accumulator_type)); llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - llvm::Value* load_init_value = b_.CreateLoad(init_value_addr); - b_.CreateStore(load_init_value, accumulator_addr); + llvm::Value* load_init_value = Load(init_value_addr); + Store(load_init_value, accumulator_addr); // The enclosing loops go over all the target elements. Now we have to compute // the actual target element. For this, we build a new loop nest to iterate @@ -1747,12 +1747,12 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( // Apply the reduction function to the loaded value. 
llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); llvm::Value* result = EmitThreadLocalCall( - *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element}, + *reduce->to_apply(), {Load(accumulator_addr), input_element}, "reduce_function"); - b_.CreateStore(result, accumulator_addr); + Store(result, accumulator_addr); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_addr); + return Load(accumulator_addr); } Status IrEmitter::HandleReduce(HloInstruction* reduce) { @@ -1762,7 +1762,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { } auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); - gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (!options::VectorizedReduceDisabled(hlo_module_config_)) { string vectorization_failure_reason; @@ -1990,7 +1990,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { [this, pad](const llvm_ir::IrArray::Index& target_index) { const HloInstruction* padding_value = pad->operand(1); llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value); - return b_.CreateLoad(padding_value_addr); + return Load(padding_value_addr); })); // Create a loop to iterate over the operand elements and update the output @@ -2012,10 +2012,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { const PaddingConfig& padding_config = pad->padding_config(); llvm_ir::IrArray::Index output_index(operand_index.GetType()); for (size_t i = 0; i < operand_index.size(); ++i) { - llvm::Value* offset = b_.CreateMul( - operand_index[i], - b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); - llvm::Value* index = b_.CreateAdd( + llvm::Value* offset = + Mul(operand_index[i], + b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); + llvm::Value* index = Add( offset, 
b_.getInt64(padding_config.dimensions(i).edge_padding_low())); output_index.push_back(index); } @@ -2102,7 +2102,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { {}, &b_, computation->name(), /*return_value_buffer=*/emitted_value_[call], /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*buffer_table_arg=*/GetBufferTableArgument(), /*profile_counters_arg=*/GetProfileCountersArgument()); HloInstruction* root = computation->root_instruction(); @@ -2117,8 +2117,8 @@ Status IrEmitter::HandleCall(HloInstruction* call) { } Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { - gtl::ArraySlice operands(custom_call->operands()); - tensorflow::StringPiece custom_call_target(custom_call->custom_call_target()); + absl::Span operands(custom_call->operands()); + absl::string_view custom_call_target(custom_call->custom_call_target()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( @@ -2126,10 +2126,10 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { for (size_t i = 0; i < operands.size(); ++i) { const HloInstruction* operand = operands[i]; llvm::Value* operand_as_i8ptr = - b_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); + PointerCast(GetEmittedValueFor(operand), i8_ptr_type); llvm::Value* slot_in_operands_alloca = - b_.CreateInBoundsGEP(operands_alloca, {b_.getInt64(i)}); - b_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); + InBoundsGEP(operands_alloca, {b_.getInt64(i)}); + Store(operand_as_i8ptr, slot_in_operands_alloca); } auto* custom_call_ir_function = llvm::cast(module_->getOrInsertFunction( @@ -2141,9 +2141,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); auto* output_address_arg = - b_.CreatePointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); + 
PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); - b_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca}); + Call(custom_call_ir_function, {output_address_arg, operands_alloca}); return Status::OK(); } @@ -2170,8 +2170,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { return InternalError( "instruction %s %s does not share slice with " "instruction %s %s", - a->ToString().c_str(), slice_a.ToString().c_str(), - b->ToString().c_str(), slice_b.ToString().c_str()); + a->ToString(), slice_a.ToString(), b->ToString(), + slice_b.ToString()); } return Status::OK(); }; @@ -2202,15 +2202,14 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), compute_function_->function()); - b_.CreateBr(header_bb); + Br(header_bb); b_.SetInsertPoint(header_bb); // Calls the condition function to determine whether to proceed with the // body. It must return a bool, so use the scalar call form. EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); - llvm::Value* while_predicate = b_.CreateICmpNE( - b_.CreateLoad( - GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), + llvm::Value* while_predicate = ICmpNE( + Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. @@ -2219,7 +2218,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); - b_.CreateCondBr(while_predicate, body_bb, exit_bb); + CondBr(while_predicate, body_bb, exit_bb); // Calls the body function from the body block. 
b_.SetInsertPoint(body_bb); @@ -2228,7 +2227,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); // Finishes with a branch back to the header. - b_.CreateBr(header_bb); + Br(header_bb); // Adds the exit block to the function and sets the insert point there. compute_function_->function()->getBasicBlockList().push_back(exit_bb); @@ -2238,7 +2237,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { } StatusOr IrEmitter::EmitFastConcatenate( - HloInstruction* concatenate, gtl::ArraySlice operands, + HloInstruction* concatenate, absl::Span operands, string* failure_reason) { if (ShouldEmitParallelLoopFor(*concatenate)) { *failure_reason = @@ -2275,7 +2274,6 @@ StatusOr IrEmitter::EmitFastConcatenate( output_min2maj.end()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); - llvm::Type* i8_type = b_.getInt8Ty(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); @@ -2298,9 +2296,9 @@ StatusOr IrEmitter::EmitFastConcatenate( // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. 
llvm::Value* target_region_begin = - b_.CreateBitCast(target_array.EmitArrayElementAddress( - outer_dims_index, &b_, "target_region"), - i8_ptr_type); + BitCast(target_array.EmitArrayElementAddress(outer_dims_index, &b_, + "target_region"), + i8_ptr_type); int64 byte_offset_into_target_region = 0; int64 inner_dims_product = @@ -2314,13 +2312,12 @@ StatusOr IrEmitter::EmitFastConcatenate( for (HloInstruction* operand : operands) { const Shape& input_shape = operand->shape(); llvm_ir::IrArray source_array = GetIrArrayFor(operand); - llvm::Value* copy_source_address = b_.CreateBitCast( + llvm::Value* copy_source_address = BitCast( source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"), i8_ptr_type); llvm::Value* copy_target_address = - b_.CreateGEP(i8_type, target_region_begin, - b_.getInt64(byte_offset_into_target_region)); + GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region)); EmitTransferElements( copy_target_address, copy_source_address, @@ -2352,15 +2349,15 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); if (element_count == 1) { - auto* load_instruction = b_.CreateAlignedLoad( - b_.CreateBitCast(source, primitive_ptr_type), element_alignment); + auto* load_instruction = + AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment); source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction); - auto* store_instruction = b_.CreateAlignedStore( - load_instruction, b_.CreateBitCast(target, primitive_ptr_type), - element_alignment); + auto* store_instruction = + AlignedStore(load_instruction, BitCast(target, primitive_ptr_type), + element_alignment); target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { - auto* memcpy_instruction = b_.CreateMemCpy( + auto* memcpy_instruction = MemCpy( target, /*DstAlign=*/element_alignment, source, /*SrcAlign=*/element_alignment, element_count * 
primitive_type_size); @@ -2376,7 +2373,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, } Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { - gtl::ArraySlice operands(concatenate->operands()); + absl::Span operands(concatenate->operands()); string failure_reason; TF_ASSIGN_OR_RETURN( bool successful, @@ -2422,9 +2419,9 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { // cond_result = true_computation(true_operand) // else // cond_result = false_computation(false_operand) - llvm::LoadInst* pred_value = b_.CreateLoad( - GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = b_.CreateICmpNE( + llvm::LoadInst* pred_value = + Load(GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ICmpNE( pred_value, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); @@ -2450,11 +2447,6 @@ Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { return Status::OK(); } -Status IrEmitter::HandleIota(HloInstruction* iota) { - // TODO(b/64798317): implement iota on CPU. 
- return Unimplemented("Iota is not implemented on CPU."); -} - Status IrEmitter::HandleRng(HloInstruction* rng) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : rng->operands()) { @@ -2511,8 +2503,8 @@ llvm::Value* IrEmitter::GetProfileCounterCommon( int64 prof_counter_idx = it->second; string counter_name = IrName("prof_counter", hlo.name()); - return b_.CreateGEP(GetProfileCountersArgument(), - b_.getInt64(prof_counter_idx), AsStringRef(counter_name)); + return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx), + AsStringRef(counter_name)); } void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b, @@ -2630,15 +2622,15 @@ llvm::Value* IrEmitter::GetProfileCountersArgument() { return compute_function_->profile_counters_arg(); } -llvm::Value* IrEmitter::GetTempBuffersArgument() { - return compute_function_->temp_buffers_arg(); +llvm::Value* IrEmitter::GetBufferTableArgument() { + return compute_function_->buffer_table_arg(); } llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return compute_function_->exec_run_options_arg(); } -llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( +llvm::Value* IrEmitter::EmitThreadLocalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address = [&]() -> llvm::Value* { @@ -2666,8 +2658,7 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( llvm::Value* params = compute_function_->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); - llvm::LoadInst* param_address_untyped = - b_.CreateLoad(param_address_offset); + llvm::LoadInst* param_address_untyped = Load(param_address_offset); if (!ShapeUtil::IsOpaque(target_shape)) { AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); @@ -2687,25 +2678,23 @@ llvm::Value* 
IrEmitter::EmitThreadLocalTempBufferPointer( auto buf_it = thread_local_buffers_.find(key); if (buf_it == thread_local_buffers_.end()) { llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( - IrShapeType(shape), - tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_, - MinimumAlignmentForShape(target_shape)); + IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()), + &b_, MinimumAlignmentForShape(target_shape)); auto it_inserted_pair = thread_local_buffers_.insert({key, buffer}); CHECK(it_inserted_pair.second); buf_it = it_inserted_pair.first; } return buf_it->second; }(); - return b_.CreateBitCast(tempbuf_address, - IrShapeType(target_shape)->getPointerTo()); + return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo()); } -llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( +llvm::Value* IrEmitter::EmitGlobalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape) { const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( - GetTempBuffersArgument(), slice.index(), &b_); - llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr); + GetBufferTableArgument(), slice.index(), &b_); + llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr); if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { tempbuf_address_base->setMetadata( @@ -2719,20 +2708,20 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( if (slice.offset() > 0) { // Adjust the address to account for the slice offset. 
tempbuf_address_untyped = - b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); + InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); } - return b_.CreateBitCast(tempbuf_address_untyped, - IrShapeType(target_shape)->getPointerTo()); + return BitCast(tempbuf_address_untyped, + IrShapeType(target_shape)->getPointerTo()); } -llvm::Value* IrEmitter::EmitTempBufferPointer( - const BufferAllocation::Slice& slice, const Shape& target_shape) { +llvm::Value* IrEmitter::EmitBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape) { if (slice.allocation()->is_thread_local()) { - return EmitThreadLocalTempBufferPointer(slice, target_shape); + return EmitThreadLocalBufferPointer(slice, target_shape); } else if (slice.allocation()->is_constant()) { return FindOrDie(constant_buffer_to_global_, slice.allocation()->index()); } else { - return EmitGlobalTempBufferPointer(slice, target_shape); + return EmitGlobalBufferPointer(slice, target_shape); } } @@ -2740,7 +2729,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { const Shape& target_shape = op->shape(); TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, assignment_.GetUniqueTopLevelSlice(op)); - llvm::Value* addr = EmitTempBufferPointer(slice, target_shape); + llvm::Value* addr = EmitBufferPointer(slice, target_shape); addr->setName(AsStringRef(IrName(op))); emitted_value_[op] = addr; return Status::OK(); @@ -2753,7 +2742,7 @@ Status IrEmitter::EmitTargetElementLoop( } Status IrEmitter::EmitTargetElementLoop( - HloInstruction* target_op, tensorflow::StringPiece desc, + HloInstruction* target_op, absl::string_view desc, const llvm_ir::ElementGenerator& element_generator) { VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); @@ -2769,8 +2758,7 @@ Status IrEmitter::EmitTargetElementLoop( TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, assignment_.GetUniqueSlice(target_op, {i})); const Shape& element_shape = 
ShapeUtil::GetSubshape(target_shape, {i}); - llvm::Value* op_target_address = - EmitTempBufferPointer(slice, element_shape); + llvm::Value* op_target_address = EmitBufferPointer(slice, element_shape); output_arrays.push_back( llvm_ir::IrArray(op_target_address, element_shape)); } @@ -2808,15 +2796,15 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* destination_value = GetEmittedValueFor(&destination); int64 source_size = ByteSizeOf(source.shape()); // TODO(b/63762267): Be more aggressive about specifying alignment. - b_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, - /*SrcAlign=*/1, source_size); + MemCpy(destination_value, /*DstAlign=*/1, source_value, + /*SrcAlign=*/1, source_size); return Status::OK(); } Status IrEmitter::ElementTypesSameAndSupported( const HloInstruction& instruction, - gtl::ArraySlice operands, - gtl::ArraySlice supported_types) { + absl::Span operands, + absl::Span supported_types) { for (auto operand : operands) { TF_RET_CHECK( ShapeUtil::SameElementType(operands[0]->shape(), operand->shape())); @@ -2827,8 +2815,8 @@ Status IrEmitter::ElementTypesSameAndSupported( if (std::find(supported_types.begin(), supported_types.end(), primitive_type) == supported_types.end()) { return Unimplemented("unsupported operand type %s in op %s", - PrimitiveType_Name(primitive_type).c_str(), - HloOpcodeString(instruction.opcode()).c_str()); + PrimitiveType_Name(primitive_type), + HloOpcodeString(instruction.opcode())); } return Status::OK(); } @@ -2846,9 +2834,10 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { } llvm::Value* IrEmitter::EmitThreadLocalCall( - const HloComputation& callee, - tensorflow::gtl::ArraySlice parameters, - tensorflow::StringPiece name) { + const HloComputation& callee, absl::Span parameters, + absl::string_view name) { + CHECK(absl::c_binary_search(thread_local_computations_, &callee)); + const Shape& return_shape = callee.root_instruction()->shape(); // Lifting this restriction to 
allow "small" arrays should be easy. Allowing @@ -2863,38 +2852,39 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( CHECK(!parameter->getType()->isPointerTy()); llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry( parameter->getType(), "arg_addr", &b_); - b_.CreateStore(parameter, parameter_addr); + Store(parameter, parameter_addr); parameter_addrs.push_back(parameter_addr); } llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(return_type, module_), - tensorflow::strings::StrCat(name, "_retval_addr"), &b_, + absl::StrCat(name, "_retval_addr"), &b_, MinimumAlignmentForPrimitiveType(return_type)); - b_.CreateCall( - FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - parameter_addrs, &b_, name, - /*return_value_buffer=*/return_value_buffer, - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), - /*profile_counters_arg=*/GetProfileCountersArgument())); + Call(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + parameter_addrs, &b_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*buffer_table_arg=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), + /*profile_counters_arg=*/GetProfileCountersArgument())); - return b_.CreateLoad(return_value_buffer); + return Load(return_value_buffer); } void IrEmitter::EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name) { - b_.CreateCall(FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - /*parameter_addresses=*/{}, &b_, name, - /*return_value_buffer=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()), - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), - /*profile_counters_arg=*/GetProfileCountersArgument())); + absl::string_view name) { 
+ CHECK(absl::c_binary_search(global_computations_, &callee)); + + Call(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + /*parameter_addresses=*/{}, &b_, name, + /*return_value_buffer=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()), + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*buffer_table_arg=*/GetBufferTableArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); } llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( @@ -2906,7 +2896,7 @@ llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( const BufferAllocation::Slice root_buffer = assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie(); - return EmitTempBufferPointer(root_buffer, root_inst->shape()); + return EmitBufferPointer(root_buffer, root_inst->shape()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index c9a1dab62dcbcd926baa82737d24efa03fd326e9..58a333b8fb2dc46868b04fec0d7d87788a809d06 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -39,13 +41,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -55,13 +56,14 @@ namespace cpu { // This class is the top-level API for the XLA HLO --> LLVM IR compiler. It // implements the DfsHloVisitor interface and emits HLO computations as LLVM IR // functions. -class IrEmitter : public DfsHloVisitorWithDefault { +class IrEmitter : public DfsHloVisitorWithDefault, + public IrBuilderMixin { public: // Create a new LLVM IR emitter. // // hlo_module: the HLO module we are emitting IR for. - // assignment: a BufferAssignment from which we know which temporary buffers - // are used by the HLO nodes. + // assignment: a BufferAssignment from which we know which buffers are used by + // the HLO nodes. // llvm_module: the LLVM module to emit IR into. // instruction_to_profile_idx: the mapping from HLO instructions to their // index in the profiling array. @@ -100,14 +102,17 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::IRBuilder<>* b() { return &b_; } + // builder() is for IrBuilderMixin. + llvm::IRBuilder<>* builder() { return &b_; } + // Emit an LLVM global variable for every constant buffer allocation. Status EmitConstantGlobals(); // Emit code to map one element according to `map_instr`. 
llvm::Value* EmitElementalMap( const HloMapInstruction& map_instr, - tensorflow::gtl::ArraySlice elemental_operands, - tensorflow::StringPiece name); + absl::Span elemental_operands, + absl::string_view name); protected: // @@ -152,7 +157,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleConditional(HloInstruction* conditional) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* gen_token) override; - Status HandleIota(HloInstruction* iota) override; Status HandleRng(HloInstruction* rng) override; Status FinishVisit(HloInstruction* root) override; @@ -215,31 +219,28 @@ class IrEmitter : public DfsHloVisitorWithDefault { // argument of the computation function being emitted by this emitter. llvm::Value* GetExecutableRunOptionsArgument(); - // Get the llvm::Value* that represents the "temps" argument of the + // Get the llvm::Value* that represents the "buffer_table" argument of the // computation function being emitted by this emitter. - llvm::Value* GetTempBuffersArgument(); + llvm::Value* GetBufferTableArgument(); - // Helper for EmitTempBufferPointer. - llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice, - const Shape& target_shape); + // Helper for EmitBufferPointer. + llvm::Value* EmitGlobalBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); - // Helper for EmitTempBufferPointer. - llvm::Value* EmitThreadLocalTempBufferPointer( + // Helper for EmitBufferPointer. + llvm::Value* EmitThreadLocalBufferPointer( const BufferAllocation::Slice& slice, const Shape& target_shape); // Emits code that computes the address of the given buffer allocation slice. - // - // TODO(sanjoy): This should be renamed to reflect that it no longer provides - // access to just temporaries. 
- llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice, - const Shape& target_shape); + llvm::Value* EmitBufferPointer(const BufferAllocation::Slice& slice, + const Shape& target_shape); // Emits a function into the current module. This can be used for // computations embedded inside other computations, such as the // function that a map operation applies. StatusOr EmitFunction( HloComputation* function, // The function to emit. - tensorflow::StringPiece + absl::string_view function_name_suffix); // Used for LLVM IR register names. // Emits a call to a thread local function (e.g. to the computation nested @@ -248,17 +249,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { // // `parameters` holds the *scalar values* that need to be passed to the // callee. The return value is the scalar returned by the callee. - llvm::Value* EmitThreadLocalCall( - const HloComputation& callee, - tensorflow::gtl::ArraySlice parameters, - tensorflow::StringPiece name); + llvm::Value* EmitThreadLocalCall(const HloComputation& callee, + absl::Span parameters, + absl::string_view name); // Emits a call to a "global" function (e.g. to the computation nested within // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to // the parameters and return values for these computations so there is no need // to explicitly pass parameters or return results. - void EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name); + void EmitGlobalCall(const HloComputation& callee, absl::string_view name); // Returns the buffer to which a global call to `callee` would have written // its result. @@ -268,8 +267,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { // match and are of one of the given supported types. 
Status ElementTypesSameAndSupported( const HloInstruction& instruction, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice supported_types); + absl::Span operands, + absl::Span supported_types); // Emit IR to perform a computation for every element in the given target op. // This produces a series of nested loops (one for each dimension of the op's @@ -285,7 +284,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, const llvm_ir::ElementGenerator& element_generator); Status EmitTargetElementLoop( - HloInstruction* target_op, tensorflow::StringPiece desc, + HloInstruction* target_op, absl::string_view desc, const llvm_ir::ElementGenerator& element_generator); // Emits a memcpy from the source instruction's result value to the @@ -316,10 +315,12 @@ class IrEmitter : public DfsHloVisitorWithDefault { // concepts that generalize over other vectorizable operations. We should // consider pulling out these abstractions into a VectorizingIrEmitter or // something similar. - StatusOr EmitVectorizedReduce( - HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, HloComputation* function, - string* failure_reason); + StatusOr EmitVectorizedReduce(HloInstruction* reduce, + HloInstruction* arg, + HloInstruction* init_value, + absl::Span dimensions, + HloComputation* function, + string* failure_reason); // We'd like to keep one or two one cache-line's worth of data in registers // without generating IR with illegal (e.g. excessively large or @@ -369,16 +370,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { const ReductionGenerator& reduction_generator, const llvm_ir::IrArray::Index& output_index, const ShardedVectorType& accumulator_type, HloInstruction* init_value, - HloInstruction* arg, tensorflow::gtl::ArraySlice dimensions, + HloInstruction* arg, absl::Span dimensions, unsigned element_alignment); // Tries to emit a fast concatenate operation using memcpy. 
Returns true if // successful, and false on failure. On failure, sets "failure_reason" to a // string describing why it could not emit a fast concatenate. - StatusOr EmitFastConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands, - string* failure_reason); + StatusOr EmitFastConcatenate(HloInstruction* concatenate, + absl::Span operands, + string* failure_reason); // Emits LLVM IR to transfer "element_count" elements of type "primitive_type" // from the address "source" to the address "target". @@ -387,8 +387,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array); - // Assignment of the temporary buffers needed by the computation and their - // shape information. + // Assignment of the buffers needed by the computation and their shape + // information. const BufferAssignment& assignment_; // The LLVM module into which IR will be emitted. @@ -568,6 +568,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::FlatMap constant_buffer_to_global_; + std::vector thread_local_computations_; + std::vector global_computations_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 2db4d000f5b149969c88fb4325ca28aa11dc3708..adfb8392bf6fa356f0a5cdab3ff74036eca8918e 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -77,19 +78,20 @@ void IrFunction::Initialize(const string& function_name, const bool optimize_for_size_requested, const bool enable_fast_math) { // The function signature is: - // void function(i8* retval, i8* run_options, i8** params, i8** temps, + // void function(i8* retval, i8* run_options, i8** params, i8** + // buffer_table, // i64* dynamic_loop_bounds, i64* prof_counters) // // For thread local functions: // retval: points to the returned value. // params: address of an array with pointers to parameters. - // temps: is null + // buffer_table: is null // // For global functions: // retval: is null // params: is null - // temps: address of an array with pointers to temporary buffers and entry - // computation parameters. + // buffer_table: address of an array with pointers to temporary buffers and + // entry computation parameters (but not to constant buffers). // // Therefore, the generated function's signature (FunctionType) is statically // determined - parameter unpacking is done in code generated into the @@ -115,7 +117,7 @@ void IrFunction::Initialize(const string& function_name, // \---------/ \---------/ \-----------/ // // /---------------------------------------------\ - // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | + // buffer_table---> | buff 0 | buff 1 | ..... | buff N-1 | // | addr | addr | | addr | // \---------------------------------------------/ // | | | @@ -133,9 +135,9 @@ void IrFunction::Initialize(const string& function_name, // prof counters -> | counter 0 | counter 1 | ..... 
| counter N-1 | // \---------------------------------------------/ - // Even though the type of params and temps is void** in the host's view, in - // LLVM IR this is represented by i8*, similarly to void*. It's up to the code - // to use GEPs to unravel the indirection layers. + // Even though the type of params and buffer_table is void** in the host's + // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to + // the code to use GEPs to unravel the indirection layers. llvm::FunctionType* function_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), /*Params=*/ @@ -159,8 +161,8 @@ void IrFunction::Initialize(const string& function_name, exec_run_options_arg_ = &*arg_iter; (++arg_iter)->setName("params"); parameters_arg_ = &*arg_iter; - (++arg_iter)->setName("temps"); - temp_buffers_arg_ = &*arg_iter; + (++arg_iter)->setName("buffer_table"); + buffer_table_arg_ = &*arg_iter; if (num_dynamic_loop_bounds_ > 0) { (++arg_iter)->setName("dynamic_loop_bounds"); dynamic_loop_bounds_arg_ = &*arg_iter; @@ -189,7 +191,7 @@ void IrFunction::Initialize(const string& function_name, llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { CHECK_GT(num_dynamic_loop_bounds_, 0); CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); - string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + string name = absl::StrCat("dynamic_loop_bound_", offset); return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), b_->getInt64(offset), AsStringRef(name))); } @@ -199,10 +201,10 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { // Returns an array of compute function call arguments (including parameter // address buffer). 
std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* b, tensorflow::StringPiece name, - llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, - llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { + absl::Span parameter_addresses, llvm::IRBuilder<>* b, + absl::string_view name, llvm::Value* return_value_buffer, + llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg, + llvm::Value* profile_counters_arg) { llvm::Value* parameter_addresses_buffer; if (parameter_addresses.empty()) { @@ -211,13 +213,13 @@ std::vector GetArrayFunctionCallArguments( } else { parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), b); + absl::StrCat(name, "_parameter_addresses"), b); for (size_t i = 0; i < parameter_addresses.size(); ++i) { llvm::Value* parameter_as_i8ptr = b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat( - name, "_parameter_", i, "_address_as_i8ptr"))); + AsStringRef(absl::StrCat(name, "_parameter_", i, + "_address_as_i8ptr"))); llvm::Value* slot_in_param_addresses = b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); @@ -229,7 +231,7 @@ std::vector GetArrayFunctionCallArguments( }; std::vector arguments{ to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), - parameter_addresses_buffer, temp_buffers_arg}; + parameter_addresses_buffer, buffer_table_arg}; if (profile_counters_arg != nullptr) { arguments.push_back(profile_counters_arg); } @@ -320,8 +322,7 @@ Status EmitCallToParallelForkJoin( /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/partitions_array, /*Name=*/ - AsStringRef( - tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + 
AsStringRef(absl::StrCat(name, "_parallel_dimension_partitions"))); // Add argument specifying parallel dimension partitions. fork_join_arguments.push_back( diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index a41cbb64cdd9f5b6de5d1eadfbf7e63e1e984801..623a5f185fa1fd0526bc8664e2ba11c9dde79b1d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ +#include "absl/types/span.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -80,8 +80,9 @@ class IrFunction { // Get the llvm::Value* that represents this functions parameters argument. llvm::Value* parameters_arg() { return parameters_arg_; } - // Get the llvm::Value* that represents this functions "temps" argument. - llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; } + // Get the llvm::Value* that represents this functions "buffer_table" + // argument. + llvm::Value* buffer_table_arg() { return buffer_table_arg_; } // Get the llvm::Value* that represents this functions "prof_counters" // argument. @@ -108,17 +109,17 @@ class IrFunction { llvm::Argument* result_arg_; llvm::Value* exec_run_options_arg_; llvm::Value* parameters_arg_; - llvm::Value* temp_buffers_arg_; + llvm::Value* buffer_table_arg_; llvm::Value* dynamic_loop_bounds_arg_ = nullptr; llvm::Value* profile_counters_arg_; }; // Returns an array of compute function call argument ir values. 
std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* b, tensorflow::StringPiece name, - llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, - llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); + absl::Span parameter_addresses, llvm::IRBuilder<>* b, + absl::string_view name, llvm::Value* return_value_buffer, + llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg, + llvm::Value* profile_counters_arg); // Emits a call to a runtime fork/join function which dispatches parallel // calls to 'parallel_function' (and joins threads before returning). diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 8560e4296aa95fe791446abb1b4363b9145f343e..f8441c3e345504616485c6b34b4302acd5cc23a3 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -15,9 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace cpu { @@ -30,8 +30,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( dynamic_loop_bounds_(dynamic_loop_bounds) {} std::vector -ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, + llvm::Type* index_type) { CHECK_NE(index_type, nullptr); CHECK(!ShapeUtil::IsTuple(shape_)); @@ -52,15 +52,15 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second; std::unique_ptr loop = loop_nest.AddLoop( - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), - start_index, end_index); + /*suffix=*/absl::StrFormat("dim.%d", dimension), start_index, + end_index); array_index[dimension] = loop->GetIndVarValue(); } else { // Emit static loop bounds for this dimension. 
std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + /*suffix=*/absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index 076c683ca566f2c53992c358903d2aadead290f9..a604e1db222139c239a2a89359a7359463e0def7 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -61,7 +61,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type) override; private: const DynamicLoopBounds* dynamic_loop_bounds_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 4fa5984b0466b178a587e97cbced97deac749f74..b4c0c09ec06bac9b5e228428c072948afdd4a547 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" @@ -109,7 +111,7 @@ ParallelTaskAssignment::ParallelTaskAssignment( : target_machine_features_(*target_machine_features) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. 
- auto cost_analysis = MakeUnique(shape_size); + auto cost_analysis = absl::make_unique(shape_size); HloComputation* computation = module->entry_computation(); Status status = computation->root_instruction()->Accept(cost_analysis.get()); if (status.ok()) { @@ -216,8 +218,7 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( // Outline 'instruction' in 'computation' for parallel task assignment. auto* call = module->OutlineExpressionFromComputation( - {instruction}, - tensorflow::strings::StrCat("parallel_", instruction->name()), + {instruction}, absl::StrCat("parallel_", instruction->name()), computation); // Set assigned dimension partitioning to 'instruction'. diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 8becc8fa23424d7454cc783eb9d853aecb5d053b..a99cd99c14abb66fc426c43656520e01f34a1700 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -73,7 +73,7 @@ class ParallelTaskAssigner : public HloPassInterface { target_machine_features_(*target_machine_features) {} ~ParallelTaskAssigner() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cpu-parallel-task-assigner"; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index 36c9f743859ae2da6c4fb3fd753bd7862fe2d3ab..a84ee78b19981e480858320e445de7f5dae27d61 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -36,7 +35,9 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : target_machine_features_([](int64 shape_size) { + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false), + target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} @@ -110,9 +111,10 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { const string hlo_string = R"( HloModule TestTaskParallel_infeed_outfeed ENTRY InfeedOutfeed { - infeed0 = (u32[12345678,2]{1,0}, token[]) infeed() + token = token[] after-all() + infeed0 = (u32[12345678,2]{1,0}, token[]) infeed(token) infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0 - ROOT outfeed0 = token[] outfeed(infeed0.data) + ROOT outfeed0 = token[] outfeed(infeed0.data, token) } )"; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc index a5f34908d70dd18ec017bdf9833c7df40f80db07..2d9492eacfea34bec3b0f1115e171a5328b7cdc3 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -61,7 +61,7 @@ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, // TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( void* result_ptr, const void* run_options_ptr, const void** params, - void** temps, uint64* prof_counters, int32 num_partitions, + void** buffer_table, uint64* prof_counters, int32 num_partitions, int64* partitions, int32 
num_partitioned_dims, void* function_ptr) { VLOG(2) << "ParallelForkJoin ENTRY" << " num_partitions: " << num_partitions @@ -81,9 +81,9 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( for (int32 i = 1; i < num_partitions; ++i) { const int64 offset = i * stride; run_options->intra_op_thread_pool()->enqueueNoNotification( - [i, function, result_ptr, run_options_ptr, temps, prof_counters, + [i, function, result_ptr, run_options_ptr, buffer_table, prof_counters, partitions, offset, &bc]() { - function(result_ptr, run_options_ptr, nullptr, temps, + function(result_ptr, run_options_ptr, nullptr, buffer_table, &partitions[offset], prof_counters); bc.DecrementCount(); VLOG(3) << "ParallelForkJoin partition " << i << " done."; @@ -91,7 +91,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin( } // Call first compute function inline. - function(result_ptr, run_options_ptr, params, temps, &partitions[0], + function(result_ptr, run_options_ptr, params, buffer_table, &partitions[0], prof_counters); VLOG(3) << "ParallelForkJoin partition 0 done."; bc.Wait(); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h index 1cf0ec6e3df400e35fa4e755a0b25b4ce7966e8f..a279c7d2d61bdd138f5285a8c8ccc89d22db9692 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h @@ -24,7 +24,7 @@ extern "C" { // threads before returning. See comments in runtime_fork_join.cc for details. 
extern void __xla_cpu_runtime_ParallelForkJoin( void* result_ptr, const void* run_options_ptr, const void** params, - void** temps, tensorflow::uint64* prof_counters, + void** buffer_table, tensorflow::uint64* prof_counters, tensorflow::int32 num_partitions, tensorflow::int64* partitions, tensorflow::int32 num_partitioned_dims, void* function_ptr); diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index f227e4ae139b92e56786e38ef8eef72c9e2cd424..942e2ddd3940fffd5d87518f059beaced3cdc925 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -67,8 +67,8 @@ int main(int argc, char** argv) { /*execution_profile=*/&profile); std::unique_ptr actual = result.ConsumeValueOrDie(); - LOG(INFO) << tensorflow::strings::Printf("computation took %lldns", - profile.compute_time_ns()); + LOG(INFO) << absl::StrFormat("computation took %dns", + profile.compute_time_ns()); LOG(INFO) << actual->ToString(); return 0; diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index ae80a6f4977f85cfd9f872734fd0a69432a1f382..7d8e51f909e3db699b745f94a6c625407bc4a6e3 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ 
b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -102,22 +102,22 @@ TEST_F(ShapePartitionIteratorTest, Shape53WithLayout10) { { ShapePartitionIterator iterator(shape, {1}); EXPECT_EQ(1, iterator.GetTotalPartitionCount()); - EXPECT_TRUE(ContainersEqual(Partition({{0, 5}}), iterator.GetPartition(0))); + EXPECT_TRUE(absl::c_equal(Partition({{0, 5}}), iterator.GetPartition(0))); } { ShapePartitionIterator iterator(shape, {2}); EXPECT_EQ(2, iterator.GetTotalPartitionCount()); - EXPECT_TRUE(ContainersEqual(Partition({{0, 2}}), iterator.GetPartition(0))); - EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(1))); + EXPECT_TRUE(absl::c_equal(Partition({{0, 2}}), iterator.GetPartition(0))); + EXPECT_TRUE(absl::c_equal(Partition({{2, 3}}), iterator.GetPartition(1))); } { ShapePartitionIterator iterator(shape, {3}); EXPECT_EQ(3, iterator.GetTotalPartitionCount()); - EXPECT_TRUE(ContainersEqual(Partition({{0, 1}}), iterator.GetPartition(0))); - EXPECT_TRUE(ContainersEqual(Partition({{1, 1}}), iterator.GetPartition(1))); - EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(2))); + EXPECT_TRUE(absl::c_equal(Partition({{0, 1}}), iterator.GetPartition(0))); + EXPECT_TRUE(absl::c_equal(Partition({{1, 1}}), iterator.GetPartition(1))); + EXPECT_TRUE(absl::c_equal(Partition({{2, 3}}), iterator.GetPartition(2))); } } @@ -128,20 +128,20 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { ShapePartitionIterator iterator(shape, {1, 1}); EXPECT_EQ(1, iterator.GetTotalPartitionCount()); EXPECT_TRUE( - ContainersEqual(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0))); + absl::c_equal(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0))); } { ShapePartitionIterator iterator(shape, {2, 2}); EXPECT_EQ(4, iterator.GetTotalPartitionCount()); EXPECT_TRUE( - ContainersEqual(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0))); + absl::c_equal(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0))); EXPECT_TRUE( - 
ContainersEqual(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1))); + absl::c_equal(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1))); EXPECT_TRUE( - ContainersEqual(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2))); + absl::c_equal(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2))); EXPECT_TRUE( - ContainersEqual(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3))); + absl::c_equal(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3))); } } diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index be772cfb7e564cebc5725854dbf5678e5c507556..bf98064647f4c29ba689902da4d737e1922391d3 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -20,13 +20,13 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/Mangler.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/Host.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" @@ -170,15 +170,14 @@ namespace { bool RegisterKnownJITSymbols() { CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global(); -#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ - do { \ - auto* function_address = \ - reinterpret_cast(__xla_cpu_runtime_##base_name); \ - registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ - function_address); \ - CHECK_EQ( \ - tensorflow::StringPiece(xla::cpu::runtime::k##base_name##SymbolName), \ - "__xla_cpu_runtime_" #base_name); \ +#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ + do { \ + auto* function_address = \ + 
reinterpret_cast(__xla_cpu_runtime_##base_name); \ + registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ + function_address); \ + CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \ + "__xla_cpu_runtime_" #base_name); \ } while (false) REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 181cec3cdddeb40daf5276d9d1d6a139417a6072..2384166fd2002a67a8aa785ad5fb341d037ee01f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -51,6 +51,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -94,6 +95,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", "@llvm//:core", ], ) @@ -108,6 +110,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -121,6 +124,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 6fcce42eaa4599eb8a6dacc1bd39eefd39aa5e50..fcd87b36b32915773546c211d7d2c447a69bef49 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -19,10 +19,10 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index d98856fdbf4165a5909f193ebe8512e21af83dfc..22721051e54e2cf9590b60333c51d1d028bb28e9 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -129,8 +129,8 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { error_spec_); } -TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { - // Test a chain of fusable ops with a non-fusable op (a reduce) thrown in the +TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { + // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the // middle. 
auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index 973aac8766f5aabca15e5173b43480c113c100dd..a434c04a980b9b3cd849792b97a0d9e965ba09f2 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -32,9 +32,9 @@ const char* const kTriple_android_arm = "armv7-none-android"; struct IntrinsicTestSpec { HloOpcode opcode; - tensorflow::StringPiece triple; - tensorflow::StringPiece features; - tensorflow::StringPiece check_lines; + absl::string_view triple; + absl::string_view features; + absl::string_view check_lines; }; // Tests that unary functions get lowered using intrinsic calls. @@ -65,9 +65,8 @@ class CpuUnaryIntrinsicTest features = ""; } - return tensorflow::strings::StrCat(opcode.c_str(), "_On_", triple.c_str(), - features.empty() ? "" : "_With", - features.c_str()); + return absl::StrCat(opcode, "_On_", triple, + (features.empty() ? 
"" : "_With"), features); } }; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index 90b99c828e2fcfd77579026a39d3a6711599feee..3b87683ffffefd2aa24dd234cc072425bef00a24 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -38,7 +38,8 @@ while_body { while_cond { arg_cond = f32[2,3,2] parameter(0) - infeed = (pred[], token[]) infeed() + token = token[] after-all() + infeed = (pred[], token[]) infeed(token) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } @@ -50,8 +51,9 @@ ENTRY main { {{2, 1}, {2001, 3002}, {2001, 2002}}}) const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body - out0 = token[] outfeed(f32[2,3,2] const_a) - ROOT out1 = token[] outfeed(f32[2,3,2] const_b) + token = token[] after-all() + out0 = token[] outfeed(f32[2,3,2] const_a, token[] token) + ROOT out1 = token[] outfeed(f32[2,3,2] const_b, token[] token) } )"; @@ -85,7 +87,8 @@ while_body { while_cond { arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) - infeed = (pred[], token[]) infeed() + token = token[] after-all() + infeed = (pred[], token[]) infeed(token) ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0 } @@ -94,8 +97,9 @@ ENTRY main { const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b) + token = token[] after-all() + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a, token[] token) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b, token[] token) } )"; diff --git 
a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 01daed4bcd38323bfe33e798a78c2b00b150a1bc..bb105194f1c9001ca4d9fff9174e1ea7e5d8b72a 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -62,7 +62,8 @@ TEST_F(CpuNoAliasTest, Concat) { // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. auto status_or_buffer_assn = BufferAssigner::Run( - hlo_module.get(), MakeUnique(hlo_module.get()), + hlo_module.get(), + absl::make_unique(hlo_module.get()), backend().compiler()->BufferSizeBytesFunction(), [](LogicalBuffer::Color) { return /*alignment=*/1; }); ASSERT_EQ(status_or_buffer_assn.status(), Status::OK()); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index dac416e1c78c2f60d458480c5062f48b77d4878d..e2c7af541eede5265f274c72f55305549f059839 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -32,7 +32,8 @@ ENTRY main { {{{1, 2}, {1001, 1002}, {2001, 2002}}, {{2, 1}, {2001, 3002}, {2001, 2002}}}) - outfeed = token[] outfeed(f32[2,3,2] const_a) + token = token[] after-all() + outfeed = token[] outfeed(f32[2,3,2] const_a, token) ROOT root = () tuple() } )"; @@ -53,6 +54,33 @@ CHECK: private constant [48 x i8] /*match_optimized_ir=*/false); } +TEST_F(CpuOutfeedTest, OutfeedTokenInTuple) { + 
const string hlo_text = R"( +HloModule OutfeedTokenInTuple + +ENTRY main { + const = f32[] constant(42) + epoch = token[] after-all() + outfeed.tok = token[] outfeed(const, epoch) + ROOT root = (token[], f32[]) tuple(outfeed.tok, const) +} +)"; + + string filecheck_pattern = R"( +CHECK: Outfeed +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 3274be8d9dbfaa55e250748a389ad34fdeb81922..1bd4b59dd604687589eee061d34aa9ca94f6d700 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "absl/algorithm/container.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -422,12 +423,12 @@ TileVariable::TileVariable(VectorSupportLibrary* vector_support, std::vector TileVariable::Get() const { std::vector result; - c_transform(storage_, std::back_inserter(result), - [&](VectorVariable vect_var) { return vect_var.Get(); }); + absl::c_transform(storage_, std::back_inserter(result), + [&](VectorVariable vect_var) { return vect_var.Get(); }); return result; } -void TileVariable::Set(tensorflow::gtl::ArraySlice value) { +void TileVariable::Set(absl::Span value) { CHECK_EQ(value.size(), storage_.size()); for (int64 i = 0, e = value.size(); i < e; i++) { storage_[i].Set(value[i]); diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index c728f6df0aef83e6ddc6c932a347f14da06d9d0d..5690d2be2fe3e21c96b51a5226e0b29148217fd1 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -18,12 +18,12 @@ limitations under the License. 
#include +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -324,7 +324,7 @@ class TileVariable { std::vector initial_value); std::vector Get() const; - void Set(tensorflow::gtl::ArraySlice value); + void Set(absl::Span value); private: std::vector storage_; diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc index 47543b2082f55cf7b8cf60f1c5bbb16a0a609912..b9e47f5aade3334bece28643e6e32ecfce3bf67b 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc @@ -37,7 +37,7 @@ void XfeedQueueManager::Reset() { } void XfeedQueueManager::EnqueueBuffersAtomically( - tensorflow::gtl::ArraySlice buffers) { + absl::Span buffers) { tensorflow::mutex_lock l(mu_); bool was_empty = enqueued_buffers_.empty(); for (XfeedBuffer* b : buffers) { diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h index b4ace232607e14fbfec01d48946f0031d96cd027..990ff94ba2338cb663b655ca3106bda83ab718a3 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h @@ -22,10 +22,10 @@ limitations under the License. 
#include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -63,8 +63,7 @@ class XfeedQueueManager { // called when the buffer will no longer be accessed by the XfeedManager, // either as a result of a call to Reset or because the runtime has dequeued // and used the buffer. - void EnqueueBuffersAtomically( - tensorflow::gtl::ArraySlice buffers); + void EnqueueBuffersAtomically(absl::Span buffers); // Blocks until the queue is non-empty, then returns the buffer at the head of // the queue. Sets the current buffer to be the returned buffer. It is an diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h index 56b28fd22da1ea6bc19f98e76f0f2ef4044cd3af..c326beb899f9a434d772c0fda032efc9113b6f42 100644 --- a/tensorflow/compiler/xla/service/defuser.h +++ b/tensorflow/compiler/xla/service/defuser.h @@ -29,7 +29,7 @@ class Defuser : public HloPassInterface { public: Defuser() {} ~Defuser() override {} - tensorflow::StringPiece name() const override { return "defuser"; } + absl::string_view name() const override { return "defuser"; } // Run defusion on the given module. Returns whether the module was // changed. diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index d938f3a2c4b5bfdd70d5a614b9890b4d7bf050f7..ba2a674d9af547ad574ae49e1e87f3afcaf6112a 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -21,8 +21,31 @@ limitations under the License. namespace xla { +namespace { + +// Pass which strips control dependencies from all instructions in the module. 
+class ControlDepRemover : public HloPassInterface { + public: + ControlDepRemover() = default; + absl::string_view name() const override { return "control-dep-remover"; } + + StatusOr Run(HloModule* module) override { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + changed = changed || !instruction->control_predecessors().empty(); + TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); + } + } + return changed; + } +}; + +} // namespace + Despecializer::Despecializer() : pipeline_("despecializer") { // TODO(b/70588125): Also deal with window reversal in a fast way. + pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); pipeline_.AddPass(); diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index cc1695b7f863805e0b483478639c17cb9061310a..7be70add2f7566376b3179740e411d6341badf7c 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -33,7 +33,7 @@ namespace xla { class Despecializer : public HloPassInterface { public: Despecializer(); - tensorflow::StringPiece name() const override { return "despecializer"; } + absl::string_view name() const override { return "despecializer"; } StatusOr Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index e228bb56bce8febcca28ae171f6de90973d020ab..edbcb25247421cdb50a845df1ec8b1851970efe3 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -25,7 +25,7 @@ namespace xla { StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( const se::Platform* platform, - tensorflow::gtl::ArraySlice stream_executors) + absl::Span stream_executors) : DeviceMemoryAllocator(platform), 
stream_executors_(stream_executors.begin(), stream_executors.end()) {} @@ -36,9 +36,8 @@ StatusOr StreamExecutorMemoryAllocator::Allocate( se::DeviceMemoryBase result = stream_executor->AllocateArray(size); if (size > 0 && result == nullptr) { return ResourceExhausted( - "Failed to allocate request for %s (%lluB) on device ordinal %d", - tensorflow::strings::HumanReadableNumBytes(size).c_str(), size, - device_ordinal); + "Failed to allocate request for %s (%uB) on device ordinal %d", + tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal); } return OwningDeviceMemory(result, device_ordinal, this); } @@ -61,12 +60,12 @@ StatusOr StreamExecutorMemoryAllocator::GetStreamExecutor( } if (device_ordinal >= stream_executors_.size()) { return InvalidArgument( - "device ordinal value (%d) >= number of devices (%zu)", device_ordinal, + "device ordinal value (%d) >= number of devices (%u)", device_ordinal, stream_executors_.size()); } if (stream_executors_[device_ordinal] == nullptr) { return NotFound("Device %s:%d present but not supported", - platform()->Name().c_str(), device_ordinal); + platform()->Name(), device_ordinal); } return stream_executors_[device_ordinal]; } diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index d87b86caf0d3acaa5bf9a455cff2315cedb2496d..a2308ee7a4137bbafe9804c30e33cc68d4628588 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -18,10 +18,10 @@ limitations under the License. 
#include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -80,7 +80,7 @@ class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { public: StreamExecutorMemoryAllocator( const se::Platform* platform, - tensorflow::gtl::ArraySlice stream_executors); + absl::Span stream_executors); StatusOr Allocate(int device_ordinal, uint64 size, bool retry_on_failure) override; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 2172ae0a29626660e8abd29a789e0baa3831519d..3e7373adc5ab8a60fd18348ce2477175aaaa8fd4 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -28,14 +28,14 @@ template Status DfsHloVisitorBase::HandleElementwiseUnary( HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); } template Status DfsHloVisitorBase::HandleElementwiseBinary( HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); } template diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 86d57581f84920e8005e8f3c420e7488fc095434..5761573791d90e45c65b55124a4bae3c5b929ef1 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -19,14 +19,14 @@ limitations under the License. 
#include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -107,6 +107,7 @@ class DfsHloVisitorBase { virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -208,7 +209,6 @@ class DfsHloVisitorBase { virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; - virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; virtual Status HandleRng(HloInstructionPtr hlo) = 0; virtual Status HandleReverse(HloInstructionPtr hlo) = 0; virtual Status HandleSort(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 617a5a2eb4796d8003099e39e3d26389e532e954..4cd10ab06cd3b804406607212d3f3c316d6cff95 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -16,14 +16,14 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -94,8 +94,11 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } - Status HandleAllToAll(HloInstructionPtr crs) override { - return DefaultAction(crs); + Status HandleAllToAll(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } + Status HandleCollectivePermute(HloInstructionPtr hlo) override { + return DefaultAction(hlo); } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); @@ -106,9 +109,6 @@ class DfsHloVisitorWithDefaultBase Status HandleOutfeed(HloInstructionPtr outfeed) override { return DefaultAction(outfeed); } - Status HandleHostCompute(HloInstructionPtr host_compute) override { - return DefaultAction(host_compute); - } Status HandleReverse(HloInstructionPtr reverse) override { return DefaultAction(reverse); } diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 12faed69677cd99c6ed82c8d13dad3138d9461b7..09cb10d6ee579111b6e0cdb460b9af2b95d090db 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -136,6 +136,7 @@ Status DecomposeBatchDot(HloInstruction* dot) { dot_dnums.add_rhs_contracting_dimensions(0); auto 
dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); + dot_r2->set_precision_config(dot->precision_config()); // Reshape Dot to R3 so we can concat along batch dimension. auto dot_r3 = computation->AddInstruction( diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h index 1959b687f16d6909a3283021c8635b3e65e6e412..fc38e317001695921d20f9bbe5775e61a8eeaa45 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.h +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -29,7 +29,7 @@ class DotDecomposer : public HloPassInterface { DotDecomposer(bool decompose_batch_dot = true) : decompose_batch_dot_(decompose_batch_dot) {} ~DotDecomposer() = default; - tensorflow::StringPiece name() const override { return "dot_decomposer"; } + absl::string_view name() const override { return "dot_decomposer"; } // Run DotDecomposer pass on computations in 'module'. // Returns whether the 'module' was changed. diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 2e9d6be2de4a2ab918d9a5ea4881ad3fd036792e..4bb1e071d8da75d0219d0b5cc9a6d16f1750a191 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -21,11 +21,15 @@ limitations under the License. 
#include // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -38,17 +42,16 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +using absl::StrCat; using llvm_ir::AsStringRef; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -using tensorflow::strings::StrCat; namespace { @@ -203,7 +206,7 @@ llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, } // namespace StatusOr ElementalIrEmitter::EmitUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { if (op->opcode() == HloOpcode::kCopy) { return operand_value; } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || @@ -217,7 +220,7 @@ StatusOr ElementalIrEmitter::EmitUnaryOp( } StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { switch (op->opcode()) { case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -229,14 +232,14 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (to_type 
== PRED) { return b_->CreateZExt( - b_->CreateICmpNE(operand_value, llvm::ConstantInt::get( - operand_value->getType(), 0)), + ICmpNE(operand_value, + llvm::ConstantInt::get(operand_value->getType(), 0)), llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } if (primitive_util::IsIntegralType(to_type)) { - return b_->CreateIntCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), - primitive_util::IsSignedIntegralType(from_type)); + return IntCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_), + primitive_util::IsSignedIntegralType(from_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (to_type == BF16) { @@ -252,19 +255,17 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( primitive_util::ComplexComponentType(to_type), module_); if (primitive_util::IsSignedIntegralType(from_type)) { return EmitComposeComplex( - op, b_->CreateSIToFP(operand_value, to_ir_component_type), - nullptr); + op, SIToFP(operand_value, to_ir_component_type), nullptr); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { return EmitComposeComplex( - op, b_->CreateUIToFP(operand_value, to_ir_component_type), - nullptr); + op, UIToFP(operand_value, to_ir_component_type), nullptr); } } return Unimplemented("conversion from primitive type %s to %s", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str()); + PrimitiveType_Name(from_type), + PrimitiveType_Name(to_type)); } case HloOpcode::kBitcastConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -275,14 +276,13 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return b_->CreateBitCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return BitCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( "bitcast conversion from primitive type %s to %s with unequal " "bit-widths 
(%u versus %u) ", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str(), + PrimitiveType_Name(from_type), PrimitiveType_Name(to_type), primitive_util::BitWidth(from_type), primitive_util::BitWidth(to_type)); } @@ -292,10 +292,8 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( if (is_signed) { auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto zero = llvm::ConstantInt::get(type, 0); - auto cmp = b_->CreateICmpSGE(operand_value, zero); - return b_->CreateSelect(cmp, operand_value, - b_->CreateNeg(operand_value)); + auto cmp = ICmpSGE(operand_value, GetZero(type)); + return Select(cmp, operand_value, Neg(operand_value)); } else { return operand_value; } @@ -307,44 +305,37 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( {operand_value->getType()}, b_); } case HloOpcode::kSign: { - bool is_signed = - primitive_util::IsSignedIntegralType(op->shape().element_type()); + CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type())) + << op->shape().element_type(); auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto zero = llvm::ConstantInt::get(type, 0); - auto cmp = b_->CreateICmpEQ(operand_value, zero); - if (is_signed) { - auto ashr = - b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1); - return b_->CreateSelect(cmp, zero, b_->CreateOr(ashr, 1)); - } else { - return b_->CreateSelect(cmp, zero, llvm::ConstantInt::get(type, 1)); - } + auto cmp = ICmpEQ(operand_value, GetZero(type)); + auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1); + return Select(cmp, GetZero(type), Or(ashr, 1)); } case HloOpcode::kNegate: - return b_->CreateNeg(operand_value); + return Neg(operand_value); case HloOpcode::kNot: { auto type = op->shape().element_type(); if (type == PRED) { // It is not sufficient to just call CreateNot() here because a PRED // is represented as an i8 and the truth value is stored only in the // bottom bit. 
- return b_->CreateZExt( - b_->CreateNot(b_->CreateTrunc(operand_value, b_->getInt1Ty())), - llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } else if (primitive_util::IsIntegralType(type)) { - return b_->CreateNot(operand_value); + return Not(operand_value); } return Unimplemented("unary op Not is not defined for type '%d'", type); } default: return Unimplemented("unary integer op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { switch (op->opcode()) { case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -361,8 +352,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitComposeComplex( op, - b_->CreateFPCast(operand_value, llvm_ir::PrimitiveTypeToIrType( - to_component_type, module_)), + FPCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), nullptr); } if (from_type == BF16) { @@ -378,26 +369,25 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } if (to_type == PRED) { return b_->CreateZExt( - b_->CreateFCmpUNE( - operand_value, - llvm::ConstantFP::get(operand_value->getType(), 0.0)), + FCmpUNE(operand_value, + llvm::ConstantFP::get(operand_value->getType(), 0.0)), llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } if (primitive_util::IsFloatingPointType(to_type)) { - return b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsSignedIntegralType(to_type)) { - return b_->CreateFPToSI( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPToSI(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, 
module_)); } if (primitive_util::IsUnsignedIntegralType(to_type)) { - return b_->CreateFPToUI( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPToUI(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return Unimplemented("unhandled conversion operation: %s => %s", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str()); + PrimitiveType_Name(from_type), + PrimitiveType_Name(to_type)); } case HloOpcode::kBitcastConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -408,14 +398,13 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return b_->CreateBitCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return BitCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( "bitcast conversion from primitive type %s to %s with unequal " "bit-widths (%u versus %u) ", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str(), + PrimitiveType_Name(from_type), PrimitiveType_Name(to_type), primitive_util::BitWidth(from_type), primitive_util::BitWidth(to_type)); } @@ -453,11 +442,10 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( // TODO(b/32151903): Ensure consistent sign behavior for -0.0. 
auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(operand_value, zero); - auto olt = b_->CreateFCmpOLT(operand_value, zero); - return b_->CreateSelect( - oeq, zero, - b_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0), + auto oeq = FCmpOEQ(operand_value, zero); + auto olt = FCmpOLT(operand_value, zero); + return Select(oeq, zero, + Select(olt, llvm::ConstantFP::get(type, -1.0), llvm::ConstantFP::get(type, 1.0))); } case HloOpcode::kIsFinite: { @@ -467,24 +455,24 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( auto abs_value = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::fabs, {operand_value}, {type}, b_); auto infinity = llvm::ConstantFP::getInfinity(type); - auto not_infinite = b_->CreateFCmpONE(abs_value, infinity); + auto not_infinite = FCmpONE(abs_value, infinity); return b_->CreateZExt(not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: - return b_->CreateFNeg(operand_value); + return FNeg(operand_value); case HloOpcode::kReal: return operand_value; case HloOpcode::kImag: return llvm::ConstantFP::get(operand_value->getType(), 0.0); default: return Unimplemented("unary floating-point op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { PrimitiveType input_type = op->operand(0)->shape().element_type(); PrimitiveType component_type = primitive_util::IsComplexType(input_type) @@ -496,12 +484,11 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); - auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); + auto sum_sq = FAdd(FMul(a, a), FMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); 
TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); } case HloOpcode::kLog1p: { // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) @@ -509,14 +496,12 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto one = llvm::ConstantFP::get(llvm_ty, 1.0); - auto a_plus_one = b_->CreateFAdd(a, one); - auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a_plus_one, a_plus_one), - b_->CreateFMul(b, b)); + auto a_plus_one = FAdd(a, one); + auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -530,11 +515,9 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( primitive_util::ComplexComponentType(to_type); auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); - return EmitComposeComplex(op, - b_->CreateFPCast(EmitExtractReal(operand_value), - to_ir_component_type), - b_->CreateFPCast(EmitExtractImag(operand_value), - to_ir_component_type)); + return EmitComposeComplex( + op, FPCast(EmitExtractReal(operand_value), to_ir_component_type), + FPCast(EmitExtractImag(operand_value), to_ir_component_type)); } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) @@ -544,8 +527,7 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); 
TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); - return EmitComposeComplex(op, b_->CreateFMul(exp_a, cos_b), - b_->CreateFMul(exp_a, sin_b)); + return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b)); } case HloOpcode::kExpm1: { // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i @@ -556,8 +538,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); - auto real_result = b_->CreateFSub(b_->CreateFMul(exp_a, cos_b), one); - auto imag_result = b_->CreateFMul(exp_a, sin_b); + auto real_result = FSub(FMul(exp_a, cos_b), one); + auto imag_result = FMul(exp_a, sin_b); return EmitComposeComplex(op, real_result, imag_result); } case HloOpcode::kCos: { @@ -572,14 +554,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); - auto half_exp_neg_b = - b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); - return EmitComposeComplex( - op, b_->CreateFMul(cos_a, b_->CreateFAdd(half_exp_neg_b, half_exp_b)), - b_->CreateFMul(sin_a, b_->CreateFSub(half_exp_neg_b, half_exp_b))); + return EmitComposeComplex(op, + FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)), + FMul(sin_a, FSub(half_exp_neg_b, half_exp_b))); } case HloOpcode::kSin: { // sin(z) = .5i(e^(-iz) - e^(iz)) @@ -595,14 +576,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, 
EmitExp(component_type, b)); - auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); - auto half_exp_neg_b = - b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); - return EmitComposeComplex( - op, b_->CreateFMul(sin_a, b_->CreateFAdd(half_exp_b, half_exp_neg_b)), - b_->CreateFMul(cos_a, b_->CreateFSub(half_exp_b, half_exp_neg_b))); + return EmitComposeComplex(op, + FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)), + FMul(cos_a, FSub(half_exp_b, half_exp_neg_b))); } case HloOpcode::kTanh: { /* @@ -630,74 +610,63 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a)); TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); - auto exp_neg_a = - b_->CreateFDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); - auto exp_2a_minus_exp_neg_2a = b_->CreateFSub( - b_->CreateFMul(exp_a, exp_a), b_->CreateFMul(exp_neg_a, exp_neg_a)); - auto cos_b_sq = b_->CreateFMul(cos_b, cos_b); - auto sin_b_sq = b_->CreateFMul(sin_b, sin_b); - auto real_num = - b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), - b_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); - auto cos_b_sin_b = b_->CreateFMul(cos_b, sin_b); - auto exp_a_plus_exp_neg_a = b_->CreateFAdd(exp_a, exp_neg_a); + auto exp_neg_a = FDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = + FSub(FMul(exp_a, exp_a), FMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = FMul(cos_b, cos_b); + auto sin_b_sq = FMul(sin_b, sin_b); + auto real_num = FAdd(FMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + FMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = FMul(cos_b, sin_b); + auto 
exp_a_plus_exp_neg_a = FAdd(exp_a, exp_neg_a); auto exp_a_plus_exp_neg_a_sq = - b_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); - auto exp_a_minus_exp_neg_a = b_->CreateFSub(exp_a, exp_neg_a); + FMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = FSub(exp_a, exp_neg_a); auto exp_a_minus_exp_neg_a_sq = - b_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); - auto imag_num = b_->CreateFMul( - cos_b_sin_b, - b_->CreateFSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); - auto denom = - b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), - b_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); - return EmitComposeComplex(op, b_->CreateFDiv(real_num, denom), - b_->CreateFDiv(imag_num, denom)); + FMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = FMul( + cos_b_sin_b, FSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); + auto denom = FAdd(FMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + FMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, FDiv(real_num, denom), + FDiv(imag_num, denom)); } case HloOpcode::kAbs: { - auto sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - b_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = FAdd( + FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)), + FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value))); return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_); } case HloOpcode::kSign: { // Sign(c) = c / |c| - auto sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - b_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = FAdd( + FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)), + FMul(EmitExtractImag(operand_value), 
EmitExtractImag(operand_value))); auto cplx_abs = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_); auto type = cplx_abs->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(cplx_abs, zero); - return b_->CreateSelect( + auto oeq = FCmpOEQ(cplx_abs, zero); + return Select( oeq, EmitComposeComplex(op, zero, zero), - EmitComposeComplex( - op, b_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), - b_->CreateFDiv(EmitExtractImag(operand_value), cplx_abs))); + EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs), + FDiv(EmitExtractImag(operand_value), cplx_abs))); } case HloOpcode::kNegate: - return EmitComposeComplex(op, - b_->CreateFNeg(EmitExtractReal(operand_value)), - b_->CreateFNeg(EmitExtractImag(operand_value))); + return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)), + FNeg(EmitExtractImag(operand_value))); case HloOpcode::kReal: return EmitExtractReal(operand_value); case HloOpcode::kImag: return EmitExtractImag(operand_value); default: return Unimplemented("unary complex op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType operand_type = op->operand(0)->shape().element_type(); if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || operand_type == PRED) { @@ -712,21 +681,20 @@ StatusOr ElementalIrEmitter::EmitBinaryOp( } StatusOr ElementalIrEmitter::EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { case HloOpcode::kComplex: return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: - return b_->CreateFAdd(lhs_value, 
rhs_value); + return FAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: - return b_->CreateFSub(lhs_value, rhs_value); + return FSub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return b_->CreateFMul(lhs_value, rhs_value); + return FMul(lhs_value, rhs_value); case HloOpcode::kDivide: - return b_->CreateFDiv(lhs_value, rhs_value); + return FDiv(lhs_value, rhs_value); case HloOpcode::kRemainder: - return b_->CreateFRem(lhs_value, rhs_value); + return FRem(lhs_value, rhs_value); // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered // comparisons always return false when one of the operands is NaN, whereas // unordered comparisons return true. @@ -763,66 +731,52 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value); default: return Unimplemented("binary floating point op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { case HloOpcode::kAdd: - return EmitComposeComplex(op, - b_->CreateFAdd(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFAdd(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex( + op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); case HloOpcode::kSubtract: - return EmitComposeComplex(op, - b_->CreateFSub(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFSub(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex( + op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); case HloOpcode::kMultiply: return EmitComposeComplex( op, - 
b_->CreateFSub(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)))); + FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))), + FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)))); case HloOpcode::kDivide: { // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) auto rhs_sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(rhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(rhs_value), - EmitExtractImag(rhs_value))); + FAdd(FMul(EmitExtractReal(rhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(rhs_value), EmitExtractImag(rhs_value))); auto type = rhs_sum_sq->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(rhs_sum_sq, zero); - auto real_inf_or_nan = b_->CreateFDiv(EmitExtractReal(lhs_value), zero); - auto imag_inf_or_nan = b_->CreateFDiv(EmitExtractImag(lhs_value), zero); - return b_->CreateSelect( + auto oeq = FCmpOEQ(rhs_sum_sq, zero); + auto real_inf_or_nan = FDiv(EmitExtractReal(lhs_value), zero); + auto imag_inf_or_nan = FDiv(EmitExtractImag(lhs_value), zero); + return Select( oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan), - EmitComposeComplex( - op, - b_->CreateFDiv( - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), - rhs_sum_sq), - b_->CreateFDiv( - b_->CreateFSub(b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractReal(lhs_value), - 
EmitExtractImag(rhs_value))), - rhs_sum_sq))); + EmitComposeComplex(op, + FDiv(FAdd(FMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), + rhs_sum_sq), + FDiv(FSub(FMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)), + FMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value))), + rhs_sum_sq))); } // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered // comparisons always return false when one of the operands is NaN, whereas @@ -832,21 +786,19 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // unordered comparison. This makes x != y equivalent to !(x == y), and // matches C++'s semantics. case HloOpcode::kEq: - return b_->CreateAnd( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); + return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); case HloOpcode::kNe: - return b_->CreateOr( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); + return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); case HloOpcode::kPower: { // (a+bi)^(c+di) = @@ -858,45 +810,43 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( auto b = EmitExtractImag(lhs_value); auto c = EmitExtractReal(rhs_value); auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = 
b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); + auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b)); auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto half_c = b_->CreateFMul(one_half, c); + auto half_c = FMul(one_half, c); TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, EmitPow(component_type, aa_p_bb, half_c)); - auto neg_d = b_->CreateFNeg(d); + auto neg_d = FNeg(d); TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); - auto neg_d_arg_lhs = b_->CreateFMul(neg_d, arg_lhs); + auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, EmitExp(component_type, neg_d_arg_lhs)); - auto coeff = b_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); - auto half_d = b_->CreateFMul(one_half, d); - auto q = b_->CreateFAdd(b_->CreateFMul(c, arg_lhs), - b_->CreateFMul(half_d, ln_aa_p_bb)); + auto half_d = FMul(one_half, d); + auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb)); TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); - return EmitComposeComplex(op, b_->CreateFMul(coeff, cos_q), - b_->CreateFMul(coeff, sin_q)); + return EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)); } default: return Unimplemented("binary complex op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + llvm::Value* rhs_value) { return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_); } llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + llvm::Value* rhs_value) { return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_); } StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, - llvm::Value* x) const { + llvm::Value* x) { 
if (prim_type != F32) { // TODO(b/34339814): Implement inverse erf for F64. return Unimplemented( @@ -906,12 +856,12 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, auto getFloat = [&](const float f) { return llvm::ConstantFP::get(b_->getFloatTy(), f); }; - auto multiply_add = [&](tensorflow::gtl::ArraySlice coefficients, + auto multiply_add = [&](absl::Span coefficients, llvm::Value* w) { llvm::Value* p = getFloat(coefficients.front()); - coefficients.pop_front(); + coefficients.remove_prefix(1); for (float coefficient : coefficients) { - p = b_->CreateFAdd(b_->CreateFMul(p, w), getFloat(coefficient)); + p = FAdd(FMul(p, w), getFloat(coefficient)); } return p; }; @@ -931,25 +881,24 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration( module_, llvm::Intrinsic::log, {b_->getFloatTy()}); - llvm::Value* w = b_->CreateFNeg(b_->CreateCall( - logf_fn, {b_->CreateFMul(b_->CreateFSub(getFloat(1.0f), x), - b_->CreateFAdd(getFloat(1.0f), x))})); + llvm::Value* w = FNeg( + Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))})); llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_); llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_->CreateFCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); + FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); // Handle true BB. SetToFirstInsertPoint(if_data.true_block, b_); { - llvm::Value* lw = b_->CreateFSub(w, getFloat(2.5f)); - tensorflow::gtl::ArraySlice lq{ + llvm::Value* lw = FSub(w, getFloat(2.5f)); + absl::Span lq{ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, -4.39150654e-06f, 0.00021858087f, -0.00125372503f, -0.00417768164f, 0.246640727f, 1.50140941f}; llvm::Value* p = multiply_add(lq, lw); - b_->CreateStore(p, p_addr); + Store(p, p_addr); } // Handle false BB. 
@@ -958,76 +907,73 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); - llvm::Value* gw = - b_->CreateFSub(b_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f)); - tensorflow::gtl::ArraySlice gq{ + llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f)); + absl::Span gq{ -0.000200214257f, 0.000100950558f, 0.00134934322f, -0.00367342844f, 0.00573950773f, -0.0076224613f, 0.00943887047f, 1.00167406f, 2.83297682f}; llvm::Value* p = multiply_add(gq, gw); - b_->CreateStore(p, p_addr); + Store(p, p_addr); } SetToFirstInsertPoint(if_data.after_block, b_); - llvm::Value* p = b_->CreateLoad(p_addr); - return b_->CreateFMul(p, x); + llvm::Value* p = Load(p_addr); + return FMul(p, x); } -StatusOr ElementalIrEmitter::EmitErfcInv( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type, + llvm::Value* value) { // Compute erfcinv(value) by calculating erfinv(1.0 - value). auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); - return EmitErfInv(prim_type, b_->CreateFSub(one, value)); + return EmitErfInv(prim_type, FSub(one, value)); } StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { auto x = value; auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); auto negative_half = llvm::ConstantFP::get(type, -0.5); // When x is large, the naive evaluation of ln(x + 1) is more // accurate than the Taylor series. 
- TF_ASSIGN_OR_RETURN(auto for_large_x, - EmitLog(prim_type, b_->CreateFAdd(x, one))); + TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one))); // The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + …. - auto for_small_x = - b_->CreateFMul(b_->CreateFAdd(b_->CreateFMul(negative_half, x), one), x); + auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x); const auto kAntilogarithmIsSmallThreshold = 1e-4; auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); - auto x_is_small = b_->CreateFCmpOLT( + auto x_is_small = FCmpOLT( abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold)); - return b_->CreateSelect(x_is_small, for_small_x, for_large_x); + return Select(x_is_small, for_small_x, for_large_x); } StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitCos(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { auto x = value; auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); @@ -1035,40 +981,40 @@ StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, // When the exponent is large, the naive evaluation of e^(x) - 1 is more // accurate than the Taylor series. 
TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value)); - auto for_large_x = b_->CreateFSub(exp_x, one); + auto for_large_x = FSub(exp_x, one); // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + …. // We want exp(x)-1 which is x + x^2/2 + x^3/6 + …. - auto x_squared = b_->CreateFAdd(x, x); - auto x_squared_over_two = b_->CreateFMul(x_squared, half); - auto for_small_x = b_->CreateFAdd(x, x_squared_over_two); + auto x_squared = FAdd(x, x); + auto x_squared_over_two = FMul(x_squared, half); + auto for_small_x = FAdd(x, x_squared_over_two); const auto kExponentIsSmallThreshold = 1e-5; auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); - auto x_is_small = b_->CreateFCmpOLT( - abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); - return b_->CreateSelect(x_is_small, for_small_x, for_large_x); + auto x_is_small = + FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); + return Select(x_is_small, for_small_x, for_large_x); } StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs}, {lhs->getType()}, b_); } StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return Unimplemented("atan2"); } StatusOr ElementalIrEmitter::EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return Unimplemented("tanh"); } StatusOr ElementalIrEmitter::EmitReducePrecision( - const HloInstruction* hlo, llvm::Value* x) const { + const HloInstruction* hlo, llvm::Value* x) { if (hlo->operand(0)->shape().element_type() != F32) { return Unimplemented("reduce-precision only implemented for F32"); } @@ -1099,23 +1045,103 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b, return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value); } 
+llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) { + return llvm::ConstantInt::get(llvm::cast(type), 1); +} + +llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) { + return llvm::ConstantInt::get(llvm::cast(type), 0); +} + +llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) { + auto* integer_type = llvm::cast(type); + return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue( + integer_type->getBitWidth())); +} + +llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) { + auto* integer_type = llvm::cast(type); + return llvm::ConstantInt::get( + integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth())); +} + +llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) { + return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0)); +} + +llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs, + llvm::Value* rhs) { + return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())), + ICmpEQ(rhs, GetMinusOne(rhs->getType()))); +} + +llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs, + llvm::Value* rhs, + bool is_signed) { + // Integer division overflow behavior: + // + // X / 0 == -1 + // INT_SMIN /s -1 = INT_SMIN + + if (!is_signed) { + llvm::Value* udiv_is_unsafe = IsZero(rhs); + llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_div = UDiv(lhs, safe_rhs); + return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div); + } + + llvm::Value* has_zero_divisor = IsZero(rhs); + llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs); + llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor); + llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_div = SDiv(lhs, safe_rhs); + + return Select( + has_zero_divisor, GetMinusOne(lhs->getType()), + Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div)); +} + +llvm::Value* 
ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs, + llvm::Value* rhs, + bool is_signed) { + // Integer remainder overflow behavior: + // + // X % 0 == X + // INT_SMIN %s -1 = 0 + + if (!is_signed) { + llvm::Value* urem_is_unsafe = IsZero(rhs); + llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_rem = URem(lhs, safe_rhs); + return Select(urem_is_unsafe, lhs, safe_rem); + } + + llvm::Value* has_zero_divisor = IsZero(rhs); + llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs); + llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor); + llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_rem = SRem(lhs, safe_rhs); + + return Select( + has_zero_divisor, lhs, + Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem)); +} + StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { + bool is_signed) { switch (op->opcode()) { // TODO(jingyue): add the "nsw" attribute for signed types. case HloOpcode::kAdd: - return b_->CreateAdd(lhs_value, rhs_value); + return Add(lhs_value, rhs_value); case HloOpcode::kSubtract: - return b_->CreateSub(lhs_value, rhs_value); + return Sub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return b_->CreateMul(lhs_value, rhs_value); + return Mul(lhs_value, rhs_value); case HloOpcode::kDivide: - return is_signed ? b_->CreateSDiv(lhs_value, rhs_value) - : b_->CreateUDiv(lhs_value, rhs_value); + return EmitIntegerDivide(lhs_value, rhs_value, is_signed); case HloOpcode::kRemainder: - return is_signed ? 
b_->CreateSRem(lhs_value, rhs_value) - : b_->CreateURem(lhs_value, rhs_value); + return EmitIntegerRemainder(lhs_value, rhs_value, is_signed); case HloOpcode::kEq: return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, rhs_value, b_); @@ -1143,11 +1169,11 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( case HloOpcode::kMaximum: return EmitIntegralMax(lhs_value, rhs_value, is_signed); case HloOpcode::kAnd: - return b_->CreateAnd(lhs_value, rhs_value); + return And(lhs_value, rhs_value); case HloOpcode::kOr: - return b_->CreateOr(lhs_value, rhs_value); + return Or(lhs_value, rhs_value); case HloOpcode::kXor: - return b_->CreateXor(lhs_value, rhs_value); + return Xor(lhs_value, rhs_value); // Shifting out bits >= the number of bits in the type being shifted // produces a poison value in LLVM which is basically "deferred undefined @@ -1156,43 +1182,43 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( // UB. case HloOpcode::kShiftRightArithmetic: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateAShr(lhs_value, rhs_value), + AShr(lhs_value, rhs_value), /*saturate_to_sign_bit=*/true); case HloOpcode::kShiftLeft: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateShl(lhs_value, rhs_value), + Shl(lhs_value, rhs_value), /*saturate_to_sign_bit=*/false); case HloOpcode::kShiftRightLogical: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateLShr(lhs_value, rhs_value), + LShr(lhs_value, rhs_value), /*saturate_to_sign_bit=*/false); default: return Unimplemented("binary integer op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { - return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE - : llvm::ICmpInst::ICMP_UGE, - lhs_value, rhs_value), - lhs_value, rhs_value); + bool is_signed) { + return Select(b_->CreateICmp(is_signed ? 
llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE, + lhs_value, rhs_value), + lhs_value, rhs_value); } llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { - return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE - : llvm::ICmpInst::ICMP_ULE, - lhs_value, rhs_value), - lhs_value, rhs_value); + bool is_signed) { + return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE, + lhs_value, rhs_value), + lhs_value, rhs_value); } llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no) const { + int64 operand_no) { CHECK(hlo.IsElementwise()) << "HLO " << hlo.ToString() << " is not elementwise."; @@ -1233,7 +1259,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( StatusOr ElementalIrEmitter::ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const { + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) { TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean, operand_to_generator.at(hlo->operand(0))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma, @@ -1251,17 +1277,17 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( // Perform the division using the float type with the same number of bits // as the raw value to avoid overflow. 
if (raw_value_size_in_bits == 32) { - elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy()); - elem_value = b_->CreateFDiv( - elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); + elem_value = UIToFP(elem_value, b_->getFloatTy()); + elem_value = FDiv(elem_value, + llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); } else { - elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy()); - elem_value = b_->CreateFDiv( + elem_value = UIToFP(elem_value, b_->getDoubleTy()); + elem_value = FDiv( elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); } if (elem_ir_ty != elem_value->getType()) { - elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty); + elem_value = FPTrunc(elem_value, elem_ir_ty); } } @@ -1269,9 +1295,7 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( switch (hlo->random_distribution()) { case RNG_UNIFORM: { if (elem_ir_ty->isFloatingPointTy()) { - return b_->CreateFAdd( - b_->CreateFMul(b_->CreateFSub(b_or_sigma, a_or_mean), elem_value), - a_or_mean); + return FAdd(FMul(FSub(b_or_sigma, a_or_mean), elem_value), a_or_mean); } else { // To generate a uniform random value in [a, b) from a raw random sample // in range [0, 2^N), we let range = b - a and return @@ -1284,22 +1308,21 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( // the same cost as if the whole warp were to re-sample. So an // efficient re-sampling implementation on GPU would need to do // nontrivial work to share entropy between threads in the warp. 
- auto range = b_->CreateSub(b_or_sigma, a_or_mean); - return b_->CreateAdd(a_or_mean, b_->CreateURem(elem_value, range)); + auto range = Sub(b_or_sigma, a_or_mean); + return Add(a_or_mean, URem(elem_value, range)); } } case RNG_NORMAL: { TF_ASSIGN_OR_RETURN( llvm::Value * r, - EmitErfcInv(elem_prim_ty, - b_->CreateFMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), - elem_value))); - return b_->CreateFAdd(b_->CreateFMul(r, b_or_sigma), a_or_mean); + EmitErfcInv(elem_prim_ty, FMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), + elem_value))); + return FAdd(FMul(r, b_or_sigma), a_or_mean); } default: return InvalidArgument( "unhandled distribution %s", - RandomDistribution_Name(hlo->random_distribution()).c_str()); + RandomDistribution_Name(hlo->random_distribution())); } } @@ -1414,8 +1437,7 @@ std::array CalculateSampleValues( // Precondition: the RNG instruction is not fused. llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) { VLOG(3) << "Using philox RNG algorithm"; CHECK(!hlo->IsFused()); // A random number generated by the per module random number generator. @@ -1438,7 +1460,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // Load the global state variable for the Philox RNG algorithm. llvm::GlobalVariable* rng_state_ptr = llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_); - llvm::Value* rng_state = b_->CreateLoad(rng_state_ptr, "rng_state_value"); + llvm::Value* rng_state = Load(rng_state_ptr, "rng_state_value"); // Build and return the elemental IR generator to generate a random value for // the element corresponding to the current thread. @@ -1464,8 +1486,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // element within the sample. 
llvm::Value* elems_per_sample_value = llvm::ConstantInt::get(index_ty, elems_per_sample); - llvm::Value* sample_idx = b_->CreateUDiv(elem_idx, elems_per_sample_value); - llvm::Value* elem_offset = b_->CreateURem(elem_idx, elems_per_sample_value); + llvm::Value* sample_idx = UDiv(elem_idx, elems_per_sample_value); + llvm::Value* elem_offset = URem(elem_idx, elems_per_sample_value); std::array counter_values = CalculateSampleValues( sample_idx, hlo_random_value, global_random_number, rng_state, b_); @@ -1473,18 +1495,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // Store the four counter_values into the sample_address alloca so we can // load the elem_offset'th one below. for (int idx = 0; idx < 4; ++idx) { - b_->CreateStore(counter_values[idx], - b_->CreateInBoundsGEP(sample_address, b_->getInt32(idx))); + Store(counter_values[idx], + InBoundsGEP(sample_address, b_->getInt32(idx))); } llvm::Type* int64_ty = b_->getInt64Ty(); CHECK(elems_per_sample == 2 || elems_per_sample == 4); llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty; // Retrieve the raw value for the current element from the current sample. 
- llvm::Value* raw_elem_value = b_->CreateLoad( - b_->CreateInBoundsGEP( - b_->CreatePointerCast(sample_address, raw_value_ty->getPointerTo()), - elem_offset), + llvm::Value* raw_elem_value = Load( + InBoundsGEP(PointerCast(sample_address, raw_value_ty->getPointerTo()), + elem_offset), "raw_elem_value"); return ConvertValueForDistribution(hlo, operand_to_generator, index, @@ -1495,7 +1516,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( StatusOr ElementalIrEmitter::EmitElementalSelect( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, operand_to_generator.at(hlo->operand(0))( ElementwiseSourceIndex(index, *hlo, 0))); @@ -1505,14 +1526,14 @@ StatusOr ElementalIrEmitter::EmitElementalSelect( TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, operand_to_generator.at(hlo->operand(2))( ElementwiseSourceIndex(index, *hlo, 2))); - return b_->CreateSelect(b_->CreateTrunc(pred_value, b_->getInt1Ty()), - on_true_value, on_false_value); + return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value, + on_false_value); } StatusOr ElementalIrEmitter::EmitElementalClamp( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * min_value, operand_to_generator.at(hlo->operand(0))( ElementwiseSourceIndex(index, *hlo, 0))); @@ -1531,14 +1552,14 @@ StatusOr ElementalIrEmitter::EmitElementalClamp( max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed); } else { return Unimplemented("Clamp unimplemented for %s", - PrimitiveType_Name(prim_type).c_str()); + PrimitiveType_Name(prim_type)); } } StatusOr ElementalIrEmitter::EmitElementalConcatenate( const HloInstruction* hlo, const 
ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& target_index) const { + const llvm_ir::IrArray::Index& target_index) { const int64 concat_dim = hlo->dimensions(0); auto source_index = target_index; @@ -1560,9 +1581,9 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( } llvm_ir::SetToFirstInsertPoint(exit_block, b_); - llvm::PHINode* output = b_->CreatePHI( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), - hlo->operands().size()); + llvm::PHINode* output = + PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + hlo->operands().size()); auto prior_insert_point = b_->GetInsertPoint(); b_->SetInsertPoint(init_block); @@ -1577,9 +1598,8 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), operand->shape().dimensions(concat_dim)); - b_->CreateCondBr( - b_->CreateICmpULT(source_index[concat_dim], concat_dim_size), - true_block, false_block); + CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), true_block, + false_block); // Create the terminator of the true block before calling operand // generators, because they require non-degenerate basic blocks. @@ -1592,11 +1612,10 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( // Subtract the size of the concat dimension of the current operand // from the source index. 
b_->SetInsertPoint(false_block); - source_index[concat_dim] = - b_->CreateSub(source_index[concat_dim], concat_dim_size); + source_index[concat_dim] = Sub(source_index[concat_dim], concat_dim_size); } - b_->CreateUnreachable(); + Unreachable(); b_->SetInsertPoint(exit_block, prior_insert_point); return output; } @@ -1604,7 +1623,7 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); const int64 rank = ShapeUtil::Rank(input_hlo->shape()); @@ -1621,7 +1640,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) - start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + start_index_value = SExtOrTrunc(start_index_value, index_type); int64 largest_valid_start_index = input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i); CHECK_GE(largest_valid_start_index, 0); @@ -1641,7 +1660,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: // input_index = start_index + offset_index - input_index[i] = b_->CreateAdd(slice_start_index[i], index[i]); + input_index[i] = Add(slice_start_index[i], index[i]); } return operand_to_generator.at(input_hlo)(input_index); } @@ -1649,7 +1668,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( StatusOr ElementalIrEmitter::EmitElementalGather( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { 
const Shape& operand_shape = hlo->operand(0)->shape(); const Shape& indices_shape = hlo->operand(1)->shape(); const Shape& output_shape = hlo->shape(); @@ -1672,22 +1691,21 @@ StatusOr ElementalIrEmitter::EmitElementalGather( std::vector operand_to_output_dim(operand_shape.dimensions_size(), -1); for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0; i < e; i++) { - if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { operand_index.push_back(index.GetConstantWithIndexType(0)); } else { - int64 output_window_dim = - dim_numbers.output_window_dims(operand_index_dim++); + int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++); operand_to_output_dim[i] = output_window_dim; operand_index.push_back(index[output_window_dim]); } } - // This is the index of the index vector in the gather_indices tensor. + // This is the index of the index vector in the start_indices tensor. IrArray::Index gather_index_index(index_type); { std::vector gather_index_index_components; for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { gather_index_index.push_back(index[i]); } } @@ -1699,8 +1717,8 @@ StatusOr ElementalIrEmitter::EmitElementalGather( auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) { llvm::Value* gather_dim_component_extended = - b_->CreateSExtOrTrunc(index_component, index_type); - int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim); + SExtOrTrunc(index_component, index_type); + int64 operand_dim = dim_numbers.start_index_map(dim); int64 output_dim = operand_to_output_dim[operand_dim]; // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim. 
// This means we set the iteration index to 0, so for the purpose of the @@ -1723,8 +1741,8 @@ StatusOr ElementalIrEmitter::EmitElementalGather( gather_dim_component_extended, is_signed), is_signed); - operand_index[operand_dim] = b_->CreateAdd( - operand_index[operand_dim], gather_dim_component_extended_inbound); + operand_index[operand_dim] = + Add(operand_index[operand_dim], gather_dim_component_extended_inbound); }; if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { @@ -1748,7 +1766,7 @@ StatusOr ElementalIrEmitter::EmitElementalGather( StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { const HloInstruction* input_hlo = hlo->operand(0); const HloInstruction* update_hlo = hlo->operand(1); const HloInstruction* start_hlo = hlo->operand(2); @@ -1771,7 +1789,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Clamp the start index so that the update region fits in the operand. 
// start_index = clamp(start_index, 0, input_dim_size - update_dim_size) - start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + start_index_value = SExtOrTrunc(start_index_value, index_type); llvm::Value* update_dim_size = index_typed_const(update_hlo->shape().dimensions(i)); int64 largest_valid_start_index = @@ -1787,14 +1805,14 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = start_index_value; - slice_limit_index[i] = b_->CreateAdd(slice_start_index[i], update_dim_size); - - slice_intersection = b_->CreateAnd( - slice_intersection, b_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection"); - slice_intersection = b_->CreateAnd( - slice_intersection, b_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection"); + slice_limit_index[i] = Add(slice_start_index[i], update_dim_size); + + slice_intersection = + And(slice_intersection, ICmpSGE(index[i], slice_start_index[i]), + "slice_intersection"); + slice_intersection = + And(slice_intersection, ICmpSLT(index[i], slice_limit_index[i]), + "slice_intersection"); } // Emit: @@ -1811,26 +1829,26 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Compute update index for intersection case. 
llvm_ir::IrArray::Index update_index(index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { - update_index[i] = b_->CreateSub(index[i], slice_start_index[i]); + update_index[i] = Sub(index[i], slice_start_index[i]); } TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); - b_->CreateStore(true_value, ret_value_addr); + Store(true_value, ret_value_addr); // Handle false BB (return data from 'input') SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * false_value, operand_to_generator.at(input_hlo)(index)); - b_->CreateStore(false_value, ret_value_addr); + Store(false_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, b_); - return b_->CreateLoad(ret_value_addr); + return Load(ret_value_addr); } StatusOr ElementalIrEmitter::EmitElementalPad( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& padded_index) const { + const llvm_ir::IrArray::Index& padded_index) { auto index = padded_index; llvm::Value* in_bounds = b_->getTrue(); for (size_t i = 0; i < index.size(); ++i) { @@ -1838,26 +1856,22 @@ StatusOr ElementalIrEmitter::EmitElementalPad( return llvm::ConstantInt::get(index[i]->getType(), n); }; const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = - b_->CreateSub(index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = b_->CreateAnd(in_bounds, - b_->CreateICmpSGE(index[i], index_typed_const(0)), - "in_bounds"); - in_bounds = b_->CreateAnd( + index[i] = Sub(index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = + And(in_bounds, ICmpSGE(index[i], index_typed_const(0)), "in_bounds"); + in_bounds = And( in_bounds, - b_->CreateICmpEQ( + ICmpEQ( index_typed_const(0), - b_->CreateURem(index[i], - index_typed_const(pad_dim.interior_padding() + 1))), - "in_bounds"); - index[i] = b_->CreateSDiv( - index[i], index_typed_const(pad_dim.interior_padding() 
+ 1)); - in_bounds = b_->CreateAnd( - in_bounds, - b_->CreateICmpSLT( - index[i], - index_typed_const(hlo->operand(0)->shape().dimensions(i))), + URem(index[i], index_typed_const(pad_dim.interior_padding() + 1))), "in_bounds"); + index[i] = + SDiv(index[i], index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = + And(in_bounds, + ICmpSLT(index[i], + index_typed_const(hlo->operand(0)->shape().dimensions(i))), + "in_bounds"); } // if (in_bounds) { @@ -1873,26 +1887,26 @@ StatusOr ElementalIrEmitter::EmitElementalPad( SetToFirstInsertPoint(if_data.true_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(0))(index)); - b_->CreateStore(operand_value, ret_value_addr); + Store(operand_value, ret_value_addr); SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - b_->CreateStore(padding_value, ret_value_addr); + Store(padding_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, b_); // Don't create phi(operand_value, padding_value) here, because invoking // operand_to_generator may create new basic blocks, making the parent // of operand_value or padding_value no longer a predecessor of // if_data.after_block. 
- return b_->CreateLoad(ret_value_addr); + return Load(ret_value_addr); } StatusOr ElementalIrEmitter::EmitElementalDot( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& dot_result_index) const { + const llvm_ir::IrArray::Index& dot_result_index) { auto lhs_generator = operand_to_generator.at(hlo->operand(0)); auto rhs_generator = operand_to_generator.at(hlo->operand(1)); @@ -1920,8 +1934,7 @@ StatusOr ElementalIrEmitter::EmitElementalDot( llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_); - b_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), - accumulator_alloca); + Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca); SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_); @@ -1943,42 +1956,37 @@ StatusOr ElementalIrEmitter::EmitElementalDot( } rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); - llvm::Value* current_accumulator = b_->CreateLoad(accumulator_alloca); + llvm::Value* current_accumulator = Load(accumulator_alloca); TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; if (primitive_util::IsComplexType(primitive_type)) { - llvm::Value* product_real = b_->CreateFSub( - b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); - llvm::Value* product_imag = b_->CreateFAdd( - b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))); - next_accumulator = b_->CreateInsertValue( + llvm::Value* product_real = + FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), 
EmitExtractImag(rhs_value))); + llvm::Value* product_imag = + FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))); + next_accumulator = InsertValue( current_accumulator, - b_->CreateFAdd(EmitExtractReal(current_accumulator), product_real), - {0}); - next_accumulator = b_->CreateInsertValue( + FAdd(EmitExtractReal(current_accumulator), product_real), {0}); + next_accumulator = InsertValue( next_accumulator, - b_->CreateFAdd(EmitExtractImag(current_accumulator), product_imag), - {1}); + FAdd(EmitExtractImag(current_accumulator), product_imag), {1}); } else if (primitive_util::IsFloatingPointType(primitive_type)) { - next_accumulator = b_->CreateFAdd(current_accumulator, - b_->CreateFMul(lhs_value, rhs_value)); + next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value)); } else { - next_accumulator = - b_->CreateAdd(current_accumulator, b_->CreateMul(lhs_value, rhs_value)); + next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value)); } - b_->CreateStore(next_accumulator, accumulator_alloca); + Store(next_accumulator, accumulator_alloca); SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_); - return b_->CreateLoad(accumulator_alloca); + return Load(accumulator_alloca); } llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: @@ -2072,10 +2080,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); auto source_index = target_index; for (int64 dim : hlo->dimensions()) { - source_index[dim] = b_->CreateSub( - llvm::ConstantInt::get(target_index[dim]->getType(), - hlo->shape().dimensions(dim) - 1), - 
target_index[dim]); + source_index[dim] = + Sub(llvm::ConstantInt::get(target_index[dim]->getType(), + hlo->shape().dimensions(dim) - 1), + target_index[dim]); } return operand_to_generator.at(operand)(source_index); }; @@ -2089,6 +2097,61 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(), hlo->dimensions(), b_)); }; + case HloOpcode::kIota: + return [this, hlo]( + const IrArray::Index& target_index) -> StatusOr { + auto* iota = Cast(hlo); + PrimitiveType element_type = iota->shape().element_type(); + IrArray::Index elem_index = + ShapeUtil::Rank(iota->shape()) > 1 + ? target_index.SourceIndexOfBroadcast( + iota->shape(), + ShapeUtil::MakeShapeWithDescendingLayout( + element_type, + {iota->shape().dimensions(iota->iota_dimension())}), + {iota->iota_dimension()}, b_) + : target_index; + llvm::Value* elem_index_linear = elem_index.linear(); + if (elem_index_linear == nullptr) { + std::vector iota_bound = { + iota->shape().dimensions(iota->iota_dimension())}; + elem_index_linear = elem_index.Linearize(iota_bound, b_); + } + Shape component_shape = + ShapeUtil::ElementIsComplex(iota->shape()) + ? 
ShapeUtil::ComplexComponentShape(iota->shape()) + : iota->shape(); + PrimitiveType component_element_type = component_shape.element_type(); + llvm::Value* iota_result; + if (ShapeUtil::ElementIsIntegral(component_shape)) { + iota_result = b_->CreateIntCast( + elem_index_linear, + llvm_ir::PrimitiveTypeToIrType(component_element_type, module_), + /*isSigned=*/false); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsFloating(component_shape)) + << component_element_type; + llvm::Type* float_ir_type; + if (component_element_type == BF16) { + float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_); + } else { + float_ir_type = + llvm_ir::PrimitiveTypeToIrType(component_element_type, module_); + } + llvm::Value* float_val = + b_->CreateUIToFP(elem_index_linear, float_ir_type); + if (component_element_type == BF16) { + iota_result = EmitF32ToBF16(float_val, b_); + } else { + iota_result = float_val; + } + } + if (ShapeUtil::ElementIsComplex(iota->shape())) { + return EmitComposeComplex(iota, iota_result, nullptr); + } else { + return iota_result; + } + }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { @@ -2154,28 +2217,28 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( default: return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); }; } } -llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const { - return b_->CreateExtractValue(value, {0}); +llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) { + return ExtractValue(value, {0}); } -llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const { - return b_->CreateExtractValue(value, {1}); +llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) { + return ExtractValue(value, {1}); } llvm::Value* 
ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const { + llvm::Value* imag) { auto cplx_type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto complex = b_->CreateInsertValue( - llvm::ConstantAggregateZero::get(cplx_type), real, {0}); + auto complex = + InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0}); if (imag != nullptr) { - complex = b_->CreateInsertValue(complex, imag, {1}); + complex = InsertValue(complex, imag, {1}); } return complex; } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 1598a4dd85632cfa9835a81a21eddff3e57bfa1f..d3e2acaabd4f602171def70ccd3d4fd5adce0d0d 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -23,12 +23,13 @@ limitations under the License. #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { -class ElementalIrEmitter { +class ElementalIrEmitter : public IrBuilderMixin { public: using HloToElementGeneratorMap = std::unordered_map; @@ -40,100 +41,114 @@ class ElementalIrEmitter { virtual ~ElementalIrEmitter() = default; virtual StatusOr EmitUnaryOp(const HloInstruction* op, - llvm::Value* operand_value) const; + llvm::Value* operand_value); virtual StatusOr EmitBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); // Returns a function to generate an element of the output of `hlo`, given a // map of functions to generate elements of its operands. 
virtual llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const; + const HloToElementGeneratorMap& operand_to_generator); - llvm::IRBuilder<>* b() const { return b_; } - llvm::Module* module() const { return module_; } + llvm::IRBuilder<>* b() { return b_; } + + // builder() is for IrBuilderMixin. + llvm::IRBuilder<>* builder() { return b_; } + + llvm::Module* module() { return module_; } protected: - virtual StatusOr EmitIntegerUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitIntegerUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); + + virtual StatusOr EmitFloatUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); - virtual StatusOr EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitComplexUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); - virtual StatusOr EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + llvm::Value* IsZero(llvm::Value* v); + llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* GetZero(llvm::Type* type); + llvm::Value* GetOne(llvm::Type* type); + llvm::Value* GetIntSMin(llvm::Type* type); + llvm::Value* GetMinusOne(llvm::Type* type); + + llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs, + bool is_signed); + llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs, + bool is_signed); virtual StatusOr EmitIntegerBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); - virtual StatusOr EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + virtual StatusOr EmitFloatBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); - virtual StatusOr EmitComplexBinaryOp( - const 
HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + virtual StatusOr EmitComplexBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); virtual StatusOr EmitErfInv(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitErfcInv(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitAtan2(PrimitiveType prim_type, - llvm::Value* lhs, - llvm::Value* rhs) const; + llvm::Value* lhs, llvm::Value* rhs); virtual StatusOr EmitLog(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitSin(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitCos(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitExp(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitPow(PrimitiveType prim_type, - llvm::Value* lhs, - llvm::Value* rhs) const; + llvm::Value* lhs, llvm::Value* rhs); virtual StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, - llvm::Value* x) const; + llvm::Value* x); - virtual 
llvm::Value* EmitExtractReal(llvm::Value* value) const; - virtual llvm::Value* EmitExtractImag(llvm::Value* value) const; + virtual llvm::Value* EmitExtractReal(llvm::Value* value); + virtual llvm::Value* EmitExtractImag(llvm::Value* value); // Composes a complex struct. imag may be nullptr for simple cast operations. llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const; + llvm::Value* imag); // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and // the target array index, computes the source array index of its @@ -142,50 +157,50 @@ class ElementalIrEmitter { // Precondition: `hlo` is an elementwise op. llvm_ir::IrArray::Index ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no) const; + int64 operand_no); // Identifier of the thread unique among all threads on the device - virtual llvm::Value* EmitThreadId() const { return b_->getIntN(128, 0); } + virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); } StatusOr EmitElementalSelect( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalClamp( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalConcatenate( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& target_index) const; + const llvm_ir::IrArray::Index& target_index); StatusOr EmitElementalDynamicSlice( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalGather( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - 
const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalDynamicUpdateSlice( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalPad( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& padded_index) const; + const llvm_ir::IrArray::Index& padded_index); StatusOr EmitElementalDot( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& dot_result_index) const; + const llvm_ir::IrArray::Index& dot_result_index); llvm::IRBuilder<>* const b_; @@ -200,13 +215,13 @@ class ElementalIrEmitter { // random number generation algorithm. llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const; + const HloToElementGeneratorMap& operand_to_generator); // Converts the raw value generated by a random number generation algorithm // to the distribution requested by the RNG HloInstruction. StatusOr ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const; + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index addb016b0481b744ff42ba827104099b6cdc3bb9..1b3be199f632a2aa6bd2c5a3820c7c5ce9b1382e 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -24,12 +24,11 @@ limitations under the License. 
namespace xla { namespace { -using tensorflow::gtl::nullopt; +using absl::nullopt; class ElementalIrEmitterExecutionTest : public HloTestBase { protected: - void RunTest(const string& hlo_text, - tensorflow::gtl::ArraySlice args) { + void RunTest(const string& hlo_text, absl::Span args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index fd75847d0c0e737957401b8efc420d504a3c0706..47c56e2f7fbd9f53be6a2b189c5c36cf4fdcdccb 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" @@ -22,16 +24,14 @@ limitations under the License. 
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/proto_serialization.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" -using tensorflow::gtl::ArraySlice; namespace xla { StatusOr> Executable::ExecuteOnStreams( - ArraySlice run_options, - ArraySlice> arguments) { + absl::Span run_options, + absl::Span> arguments) { TF_RET_CHECK(run_options.size() == arguments.size()); std::vector return_values; @@ -62,7 +62,7 @@ StatusOr> Executable::ExecuteOnStreams( StatusOr Executable::ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, - ArraySlice arguments) { + absl::Span arguments) { se::Stream* stream = run_options->stream(); std::unique_ptr timer; if (profile != nullptr) { @@ -76,8 +76,8 @@ StatusOr Executable::ExecuteOnStreamWrapper( std::unique_ptr profile_ptr = module_config().debug_options().xla_hlo_profile() && hlo_profiling_enabled() - ? MakeUnique(&hlo_profile_printer_data(), - &hlo_profile_index_map()) + ? 
absl::make_unique(&hlo_profile_printer_data(), + &hlo_profile_index_map()) : nullptr; StatusOr return_value = @@ -154,9 +154,9 @@ Status Executable::DumpHloSnapshot() { const string& directory_path = module_config().debug_options().xla_dump_executions_to(); const auto& module = hlo_snapshot_->hlo().hlo_module(); - string filename = tensorflow::strings::Printf( - "computation_%lld__%s__execution_%lld", module.id(), - module.entry_computation_name().c_str(), ++execution_count_); + string filename = + absl::StrFormat("computation_%d__%s__execution_%d", module.id(), + module.entry_computation_name(), ++execution_count_); return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); } diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 98eaeee30a693211ae564a5ef3c373f0364bef11..3a6780f2a67f230cae626ea00cfbf93b4e60d968 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -18,7 +18,10 @@ limitations under the License. #include #include +#include +#include "absl/types/span.h" +#include "absl/types/variant.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -26,18 +29,33 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" +#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" namespace xla { +// ExecutionOutput encapsulates the output buffers of an execution and the +// leftover buffers to be released by the caller. +struct ExecutionOutput { + ExecutionOutput(ScopedShapedBuffer result, + std::vector to_be_released) + : result(std::move(result)), to_be_released(std::move(to_be_released)) {} + ScopedShapedBuffer result; + + // Leftover buffers for the caller to release. Elements in this list are + // donated input memory buffers that are not reused by XLA as outputs. + std::vector to_be_released; +}; + // A given platform's compiler will produce an Executable -- this is a uniform // interface that is used for launching compiled programs across platforms. class Executable { @@ -63,25 +81,46 @@ class Executable { // Returns a shaped buffer containing the result of the computation.
virtual StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) = 0; // Same as ExecuteOnStream(), but this call is non-blocking and returns as // soon as all of the operations are enqueued for launch on the stream. virtual StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) = 0; + absl::Span arguments) = 0; + + // Starts the given program executing on the given stream/executor. + // + // `arguments` are ShapeTree containing the input parameters. For each element + // in the shape tree, if the element holds the ownership of the memory, it is + // considered donated and XLA will potentially reuse it as output buffers. For + // all donated inputs, XLA is also responsible for freeing them. + // + // If an input is donated to XLA but is not reused as output, it is returned + // as a leftover buffer for the caller to release. + virtual StatusOr ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + std::vector> arguments, + HloExecutionProfile* hlo_execution_profile) { + return Unimplemented( + "MaybeOwningDeviceMemory version of overload is not implemented "); + } + + virtual StatusOr ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + std::vector> arguments) { + return Unimplemented( + "MaybeOwningDeviceMemory version of overload is not implemented "); + } // Same as ExecuteOnStream(), but runs this executable on multiple // streams. arguments[i] contains the arguments to the execution on // run_options[i]->stream() and the returned value is at index i of the // returned vector. virtual StatusOr> ExecuteOnStreams( - tensorflow::gtl::ArraySlice - run_options, - tensorflow::gtl::ArraySlice< - tensorflow::gtl::ArraySlice> - arguments); + absl::Span run_options, + absl::Span> arguments); // Populates `hlo_execution_profile` from `executor`.
This is implicit in any // Execute* API call that takes a hlo_execution_profile argument, but must be @@ -97,7 +136,7 @@ class Executable { // given ExecutionProfile if non-null. StatusOr ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); // Returns the ExecutionProfile from executing on the device. This includes // the number of cycles taken for the computation or the compilation time. diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 228c3fac95c3114484637bd93ec51c60b44403cc..997db7c058af6da8ecff399769b85b803e2e5785 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -53,8 +53,8 @@ ExecutionHandle ExecutionTracker::Register(Backend* backend, tensorflow::mutex_lock lock(execution_mutex_); int64 handle = next_handle_++; auto inserted = handle_to_execution_.emplace( - handle, - MakeUnique(backend, std::move(streams), profile, result)); + handle, absl::make_unique(backend, std::move(streams), + profile, result)); CHECK(inserted.second); ExecutionHandle execution_handle; @@ -66,7 +66,7 @@ Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { - return NotFound("no execution record for execution handle: %lld", + return NotFound("no execution record for execution handle: %d", handle.handle()); } handle_to_execution_.erase(handle.handle()); @@ -78,7 +78,7 @@ StatusOr 
ExecutionTracker::Resolve( tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { - return NotFound("no execution record for execution handle: %lld", + return NotFound("no execution record for execution handle: %d", handle.handle()); } return it->second.get(); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h index d3efab3614912e4b0c2c8aa3b80277c326382ed0..3cccec9862e0f92df478006939552099868121b9 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.h +++ b/tensorflow/compiler/xla/service/flatten_call_graph.h @@ -28,7 +28,7 @@ namespace xla { // points-to analysis (see b/36865746 for details). class FlattenCallGraph : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "flatten-call-graph"; } + absl::string_view name() const override { return "flatten-call-graph"; } // Duplicates computations called from multiple call- or while-nodes to // flatten the call graph. diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index e3a42d0d06be9e4c9ef96ed2e6ff5daa8eebaf3e..cb86c9857936f21d9d2ac6bc22c725b89cca6482 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -24,88 +25,87 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" namespace xla { -using tensorflow::gtl::ArraySlice; static StatusOr TransposeIndexVectorDimToLast( - HloInstruction* gather_indices, int64 index_vector_dim) { - const Shape& gather_indices_shape = gather_indices->shape(); + HloInstruction* start_indices, int64 index_vector_dim) { + const Shape& start_indices_shape = start_indices->shape(); - if (gather_indices_shape.dimensions_size() == index_vector_dim) { - return gather_indices; + if (start_indices_shape.dimensions_size() == index_vector_dim) { + return start_indices; } - if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) { - return gather_indices; + if (index_vector_dim == (start_indices_shape.dimensions_size() - 1)) { + return start_indices; } std::vector permutation; - permutation.reserve(gather_indices_shape.dimensions_size()); - for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + permutation.reserve(start_indices_shape.dimensions_size()); + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { if (i != index_vector_dim) { permutation.push_back(i); } } permutation.push_back(index_vector_dim); - return MakeTransposeHlo(gather_indices, permutation); + return MakeTransposeHlo(start_indices, permutation); } -// Canonicalizes the gather_indices tensors so that we only have deal with some +// Canonicalizes the start_indices tensors so that we only have to deal with some // specific cases in the while loop that does the heavy lifting. // // See the "High Level Algorithm" section for a broader picture. static StatusOr CanonicalizeGatherIndices( - HloInstruction* gather_indices, int64 index_vector_dim) { + HloInstruction* start_indices, int64 index_vector_dim) { // Transpose the non-index-vector dimensions to the front.
TF_ASSIGN_OR_RETURN( - HloInstruction * transposed_gather_indices, - TransposeIndexVectorDimToLast(gather_indices, index_vector_dim)); + HloInstruction * transposed_start_indices, + TransposeIndexVectorDimToLast(start_indices, index_vector_dim)); bool indices_are_scalar = - index_vector_dim == gather_indices->shape().dimensions_size(); + index_vector_dim == start_indices->shape().dimensions_size(); - // The number of dimensions in gather_indices that are index dimensions. - const int64 index_dims_in_gather_indices = indices_are_scalar ? 0 : 1; + // The number of dimensions in start_indices that are index dimensions. + const int64 index_dims_in_start_indices = indices_are_scalar ? 0 : 1; - // If there is only one index (i.e. gather_indices has rank 1 and this gather + // If there is only one index (i.e. start_indices has rank 1 and this gather // is really just a dynamic slice) add a leading degenerate dimension for // uniformity. Otherwise create a "collapsed" leading dimension that subsumes // all of the non-index-vector dimensions. - const Shape& shape = transposed_gather_indices->shape(); - if (shape.dimensions_size() == index_dims_in_gather_indices) { - return PrependDegenerateDims(transposed_gather_indices, 1); + const Shape& shape = transposed_start_indices->shape(); + if (shape.dimensions_size() == index_dims_in_start_indices) { + return PrependDegenerateDims(transposed_start_indices, 1); } else { - // Collapse all but the dimensions (0 or 1) in gather_indices containing the + // Collapse all but the dimensions (0 or 1) in start_indices containing the // index vectors. return CollapseFirstNDims( - transposed_gather_indices, - shape.dimensions_size() - index_dims_in_gather_indices); + transposed_start_indices, + shape.dimensions_size() - index_dims_in_start_indices); } } // Expands out or contracts away the gather dimensions in the accumulator // produced by the while loop. 
-static StatusOr AdjustGatherDimsInAccumulator( - const Shape& gather_indices_shape, HloInstruction* accumulator, +static StatusOr AdjustBatchDimsInAccumulator( + const Shape& start_indices_shape, HloInstruction* accumulator, int64 index_vector_dim) { - std::vector output_gather_dim_bounds; - output_gather_dim_bounds.reserve(gather_indices_shape.dimensions_size()); - for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + std::vector batch_dim_bounds; + batch_dim_bounds.reserve(start_indices_shape.dimensions_size()); + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { if (i != index_vector_dim) { - output_gather_dim_bounds.push_back(gather_indices_shape.dimensions(i)); + batch_dim_bounds.push_back(start_indices_shape.dimensions(i)); } } - if (output_gather_dim_bounds.empty()) { - // If output_gather_dim_bounds is empty we must be lowering a (effectively) + if (batch_dim_bounds.empty()) { + // If batch_dim_bounds is empty we must be lowering a (effectively) // dynamic-slice. In that case, there is a leading degenerate gather // dimension that we added to make this special case play well with the // general while loop which we need to remove now. return ElideDegenerateDims(accumulator, {0}); } - return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds); + return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds); } -// Expand an index vector from the gather_indices tensor into a vector that can +// Expand an index vector from the start_indices tensor into a vector that can // be used to dynamic-slice out of the gather operand. 
static StatusOr ExpandIndexVectorIntoOperandSpace( HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers, @@ -121,10 +121,8 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( std::vector expanded_index_components; for (int i = 0; i < operand_rank; i++) { - int64 index_vector_dim_index = - FindIndex(dim_numbers.gather_dims_to_operand_dims(), i); - if (index_vector_dim_index != - dim_numbers.gather_dims_to_operand_dims_size()) { + int64 index_vector_dim_index = FindIndex(dim_numbers.start_index_map(), i); + if (index_vector_dim_index != dim_numbers.start_index_map_size()) { TF_ASSIGN_OR_RETURN( HloInstruction * component_to_concat, MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, @@ -147,10 +145,10 @@ static StatusOr> GatherLoopBody( const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers(); CHECK_EQ(incoming_loop_state.size(), 3); HloInstruction* const operand = incoming_loop_state[0]; - HloInstruction* const gather_indices = incoming_loop_state[1]; + HloInstruction* const start_indices = incoming_loop_state[1]; HloInstruction* const output_accumulator = incoming_loop_state[2]; - bool has_scalar_indices = gather_indices->shape().dimensions_size() == 1; + bool has_scalar_indices = start_indices->shape().dimensions_size() == 1; CHECK_EQ(has_scalar_indices, dim_numbers.index_vector_dim() == gather.operand(1)->shape().dimensions_size()); @@ -163,24 +161,24 @@ static StatusOr> GatherLoopBody( HloInstruction* index_vector; if (has_scalar_indices) { - // In this case gather_indices has rank 1 and induction_var_as_vector (of + // In this case start_indices has rank 1 and induction_var_as_vector (of // shape {1}) is an index into this rank 1 tensor. 
TF_ASSIGN_OR_RETURN( index_vector, - MakeDynamicSliceHlo(gather_indices, induction_var_as_vector, {1})); + MakeDynamicSliceHlo(start_indices, induction_var_as_vector, {1})); } else { - // In this case gather_indices has rank 2 and induction_var_as_vector (of + // In this case start_indices has rank 2 and induction_var_as_vector (of // shape {1}) is an index into just the first dimension of this rank 2 // tensor. TF_ASSIGN_OR_RETURN( - HloInstruction * index_into_gather_indices, + HloInstruction * index_into_start_indices, PadVectorWithZeros(induction_var_as_vector, /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); - int64 index_vector_size = gather_indices->shape().dimensions(1); + int64 index_vector_size = start_indices->shape().dimensions(1); TF_ASSIGN_OR_RETURN( HloInstruction * index_vector_2d, - MakeDynamicSliceHlo(gather_indices, index_into_gather_indices, + MakeDynamicSliceHlo(start_indices, index_into_start_indices, {1, index_vector_size})); TF_ASSIGN_OR_RETURN(index_vector, @@ -194,26 +192,26 @@ static StatusOr> GatherLoopBody( TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice, MakeDynamicSliceHlo(operand, gathered_slice_start, - gather.gather_window_bounds())); + gather.gather_slice_sizes())); TF_ASSIGN_OR_RETURN( - HloInstruction * gathered_slice_with_dims_elided, + HloInstruction* const gathered_slice_with_dims_collapsed, ElideDegenerateDims(gathered_slice, - AsInt64Slice(dim_numbers.elided_window_dims()))); + AsInt64Slice(dim_numbers.collapsed_slice_dims()))); TF_ASSIGN_OR_RETURN( - HloInstruction * gathered_slice_for_update, - PrependDegenerateDims(gathered_slice_with_dims_elided, 1)); + HloInstruction* const gathered_slice_for_update, + PrependDegenerateDims(gathered_slice_with_dims_collapsed, 1)); TF_ASSIGN_OR_RETURN( - HloInstruction * index_vector_into_accumulator, + HloInstruction* const index_vector_into_accumulator, PadVectorWithZeros( induction_var_as_vector, /*zeros_to_prepend=*/0, /*zeros_to_append=*/ - 
gathered_slice_with_dims_elided->shape().dimensions_size())); + gathered_slice_with_dims_collapsed->shape().dimensions_size())); TF_ASSIGN_OR_RETURN( - HloInstruction * updated_accumulator, + HloInstruction* const updated_accumulator, MakeDynamicUpdateSliceHlo(output_accumulator, gathered_slice_for_update, index_vector_into_accumulator)); @@ -221,19 +219,19 @@ static StatusOr> GatherLoopBody( // WhileUtil::MakeCountedLoop functions takes care of the induction variable // and the while loop exit condition. return StatusOr>{ - {operand, gather_indices, updated_accumulator}}; + {operand, start_indices, updated_accumulator}}; } static StatusOr CreateGatherLoopAccumulatorInitValue( HloComputation* computation, PrimitiveType element_type, - ArraySlice window_bounds, int64 gather_loop_trip_count, + absl::Span slice_sizes, int64 gather_loop_trip_count, const GatherDimensionNumbers& dim_numbers) { std::vector accumulator_state_shape_dims; - accumulator_state_shape_dims.reserve(1 + window_bounds.size()); + accumulator_state_shape_dims.reserve(1 + slice_sizes.size()); accumulator_state_shape_dims.push_back(gather_loop_trip_count); - for (int64 i = 0; i < window_bounds.size(); i++) { - if (!c_binary_search(dim_numbers.elided_window_dims(), i)) { - accumulator_state_shape_dims.push_back(window_bounds[i]); + for (int64 i = 0; i < slice_sizes.size(); i++) { + if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { + accumulator_state_shape_dims.push_back(slice_sizes[i]); } } return BroadcastZeros(computation, element_type, @@ -241,23 +239,23 @@ static StatusOr CreateGatherLoopAccumulatorInitValue( } // `accumulator` is almost the tensor the gather operation would have produced, -// except that it has the dimensions in the wrong order -- the gather dimensions -// are the major dimensions and the window dimensions are the minor dimensions. 
+// except that it has the dimensions in the wrong order -- the batch dimensions +// are the major dimensions and the offset dimensions are the minor dimensions. // Fix this up with a transpose. -static StatusOr PermuteGatherAndWindowDims( - HloInstruction* accumulator, ArraySlice output_window_dims, +static StatusOr PermuteBatchAndOffsetDims( + HloInstruction* accumulator, absl::Span offset_dims, int64 output_rank) { std::vector permutation; permutation.reserve(output_rank); - int64 gather_idx_counter = 0; - int64 window_idx_counter = output_rank - output_window_dims.size(); + int64 batch_idx_counter = 0; + int64 offset_idx_counter = output_rank - offset_dims.size(); for (int64 i = 0; i < output_rank; i++) { - bool is_window_dim = c_binary_search(output_window_dims, i); - if (is_window_dim) { - permutation.push_back(window_idx_counter++); + bool is_offset_dim = absl::c_binary_search(offset_dims, i); + if (is_offset_dim) { + permutation.push_back(offset_idx_counter++); } else { - permutation.push_back(gather_idx_counter++); + permutation.push_back(batch_idx_counter++); } } @@ -268,11 +266,11 @@ static StatusOr PermuteGatherAndWindowDims( // // We follow the following steps in sequence: // -// 1. We canonicalize the gather_indices tensor such that it has rank +// 1. We canonicalize the start_indices tensor such that it has rank // 2 (i.e. is a matrix) where each row is an index vector into the // operand. // 2. We iterate over the set of indices in the canonicalized -// gather_indices tensor using a while loop, accumulating slices +// start_indices tensor using a while loop, accumulating slices // of the operand tensor into an accumulator using // DynamicUpdateSlice. // 3. 
The accumulator result from the while loop from (2) is then @@ -287,11 +285,11 @@ static StatusOr PermuteGatherAndWindowDims( // operand = s32[3,3] parameter(0) // indices = s32[2,2] parameter(1) // ROOT gather = s32[2,3,2] gather(operand, indices), -// output_window_dims={1}, -// elided_window_dims={1}, -// gather_dims_to_operand_dims={1}, +// offset_dims={1}, +// collapsed_slice_dims={1}, +// start_index_map={1}, // index_vector_dim=2, -// window_bounds={3, 1} +// slice_sizes={3, 1} // } // // We'd first reshape indices to s32[4,1], where each row is an index @@ -305,8 +303,8 @@ StatusOr GatherExpander::ExpandGather( HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); - HloInstruction* gather_indices = gather_instr->mutable_operand(1); - const Shape& gather_indices_shape = gather_indices->shape(); + HloInstruction* start_indices = gather_instr->mutable_operand(1); + const Shape& start_indices_shape = start_indices->shape(); const Shape& output_shape = gather_instr->shape(); int64 output_rank = output_shape.dimensions_size(); @@ -314,9 +312,9 @@ StatusOr GatherExpander::ExpandGather( gather_instr->gather_dimension_numbers(); int64 gather_loop_trip_count = 1; - for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { if (i != dim_numbers.index_vector_dim()) { - gather_loop_trip_count *= gather_indices_shape.dimensions(i); + gather_loop_trip_count *= start_indices_shape.dimensions(i); } } @@ -324,27 +322,27 @@ StatusOr GatherExpander::ExpandGather( return Unimplemented( "Gather operations with more than 2147483647 gather indices are not " "supported. 
This error occurred for %s.", - gather_instr->ToString().c_str()); + gather_instr->ToString()); } - TF_ASSIGN_OR_RETURN(HloInstruction * canonical_gather_indices, - CanonicalizeGatherIndices( - gather_indices, dim_numbers.index_vector_dim())); + TF_ASSIGN_OR_RETURN( + HloInstruction * canonical_start_indices, + CanonicalizeGatherIndices(start_indices, dim_numbers.index_vector_dim())); CHECK_EQ(gather_loop_trip_count, - canonical_gather_indices->shape().dimensions(0)); + canonical_start_indices->shape().dimensions(0)); TF_ASSIGN_OR_RETURN( HloInstruction * accumulator_init, CreateGatherLoopAccumulatorInitValue( computation, output_shape.element_type(), - gather_instr->gather_window_bounds(), gather_loop_trip_count, + gather_instr->gather_slice_sizes(), gather_loop_trip_count, gather_instr->gather_dimension_numbers())); StatusOr> gather_loop_result_or_error = WhileUtil::MakeCountedLoop( computation, gather_loop_trip_count, - {operand, canonical_gather_indices, accumulator_init}, + {operand, canonical_start_indices, accumulator_init}, [&](HloInstruction* indvar, const std::vector& loop_state) { return GatherLoopBody(*gather_instr, indvar, loop_state); @@ -356,13 +354,13 @@ StatusOr GatherExpander::ExpandGather( HloInstruction* accumulator_result = gather_loop_result.back(); TF_ASSIGN_OR_RETURN( - HloInstruction * accumulator_with_output_gather_dims_decanonicalized, - AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_result, - dim_numbers.index_vector_dim())); + HloInstruction* const accumulator_with_batch_dims_decanonicalized, + AdjustBatchDimsInAccumulator(start_indices->shape(), accumulator_result, + dim_numbers.index_vector_dim())); - return PermuteGatherAndWindowDims( - accumulator_with_output_gather_dims_decanonicalized, - AsInt64Slice(dim_numbers.output_window_dims()), output_rank); + return PermuteBatchAndOffsetDims(accumulator_with_batch_dims_decanonicalized, + AsInt64Slice(dim_numbers.offset_dims()), + output_rank); } StatusOr 
GatherExpander::Run(HloModule* module) { @@ -375,8 +373,8 @@ StatusOr GatherExpander::Run(HloModule* module) { std::vector gather_instrs; for (HloComputation* computation : module->MakeNonfusionComputations()) { - c_copy_if(computation->instructions(), std::back_inserter(gather_instrs), - is_nontrivial_gather); + absl::c_copy_if(computation->instructions(), + std::back_inserter(gather_instrs), is_nontrivial_gather); } for (HloInstruction* inst : gather_instrs) { diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index c1fc8574da99fff223c7dbb570b4533f76905b9a..7bd9ea598417a931d2df507d472c6a60be05e0bc 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -25,7 +25,7 @@ namespace xla { // nevertheless have a minimum level of support. class GatherExpander : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "gather_expander"; } + absl::string_view name() const override { return "gather_expander"; } StatusOr Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 020ffcd106862cb2641a9f3bceb70acdd969a458..141dd4d6f10272ce749edc4e91153c365ed322e6 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -28,11 +28,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2147483647,5] parameter(1) ROOT gather = s32[2147483647,3,5] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -55,11 +55,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = 
s32[2] parameter(1) ROOT gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 0ce2db907b643f3beabd127388370dbe601179e1..4ed91ef18768d09c252d1b73890637227f0ce717 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -42,8 +42,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const { } Status GenericTransferManager::WriteSingleTupleIndexTable( - se::Stream* stream, - tensorflow::gtl::ArraySlice elements, + se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) { TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape)); @@ -163,7 +162,7 @@ Status GenericTransferManager::TransferLiteralFromOutfeed( } Status GenericTransferManager::ResetDevices( - tensorflow::gtl::ArraySlice + absl::Span /*executors*/) { return Unimplemented( "Device reset is not yet supported on this platform (b/30481585)"); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 6c1a21587a7ef5199afb93715dc57be5139fbc22..86c8b1c145a25149a25e7b272babc5c858d476af 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -55,15 +55,13 @@ class GenericTransferManager : public TransferManager { const Shape& literal_shape, MutableBorrowingLiteral literal) override; - Status ResetDevices( - tensorflow::gtl::ArraySlice executors) override; + Status ResetDevices(absl::Span executors) override; int64 
GetByteSizeRequirement(const Shape& shape) const override; protected: Status WriteSingleTupleIndexTable( - se::Stream* stream, - tensorflow::gtl::ArraySlice elements, + se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index bacd2c1f146385dac0ec5978c5f8ec9f463cf550..a68b7a1bef81e369dc1bbcd249642e5b80401c64 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,6 +1,7 @@ # Description: # GPU-specific components in XLA service implementation. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") licenses(["notice"]) # Apache 2.0 @@ -55,6 +56,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -90,6 +93,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -106,6 +110,8 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -125,6 +131,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -170,6 +178,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + 
"//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:kernel_tiling", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", @@ -179,6 +188,12 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], @@ -223,6 +238,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", ], @@ -242,6 +259,8 @@ cc_library( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -256,6 +275,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -336,6 +356,11 @@ cc_library( "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -361,15 +386,20 @@ cc_library( hdrs = ["cudnn_convolution_algorithm_picker.h"], deps = [ ":backend_configs", + ":buffer_comparator", ":cudnn_convolution_runner", 
":gpu_executable", ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", ], ) @@ -387,6 +417,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -417,7 +448,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], @@ -428,6 +459,7 @@ cc_library( srcs = ["instruction_fusion.cc"], hdrs = ["instruction_fusion.h"], deps = [ + ":gpu_fusible", ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -457,12 +489,14 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ + ":gpu_fusible", ":instruction_fusion", ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:multi_output_fusion", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -480,6 +514,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -503,6 +538,7 @@ cc_library( srcs = ["fusion_merger.cc"], hdrs = ["fusion_merger.h"], deps = [ + 
":gpu_fusible", ":instruction_fusion", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -510,6 +546,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -541,6 +579,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:shape_inference", + "@com_google_absl//absl/memory", ], ) @@ -597,6 +636,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:core", ], alwayslink = True, # Contains per-platform transfer manager registration @@ -613,9 +653,9 @@ cc_library( ":gpu_constants", ":gpu_copy_insertion", ":gpu_executable", + ":gpu_hlo_schedule", ":gpu_hlo_support_checker", ":gpu_layout_assignment", - ":hlo_schedule", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", @@ -666,6 +706,10 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@llvm//:core", ], alwayslink = True, # Contains compiler registration @@ -698,8 +742,8 @@ cc_library( ":xfeed_queue", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -714,6 +758,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -752,39 +797,42 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", 
"//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/strings", ], ) cc_library( - name = "hlo_schedule", - srcs = ["hlo_schedule.cc"], - hdrs = ["hlo_schedule.h"], + name = "gpu_hlo_schedule", + srcs = ["gpu_hlo_schedule.cc"], + hdrs = ["gpu_hlo_schedule.h"], deps = [ ":stream_assignment", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/compiler/xla/service:hlo_scheduling", + "@com_google_absl//absl/memory", ], ) tf_cc_test( - name = "hlo_schedule_test", + name = "gpu_hlo_schedule_test", srcs = [ - "hlo_schedule_test.cc", + "gpu_hlo_schedule_test.cc", ], deps = [ - ":hlo_schedule", + ":gpu_hlo_schedule", ":stream_assignment", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -835,7 +883,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -853,3 +903,57 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +cc_library( + name = "buffer_comparator", + srcs = ["buffer_comparator.cc"], + hdrs = ["buffer_comparator.h"], + deps = [ + ":gpu_executable", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/service:compiler", + 
"//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", + ], +) + +xla_test( + name = "buffer_comparator_test", + srcs = ["buffer_comparator_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":buffer_comparator", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "gpu_fusible", + srcs = ["gpu_fusible.cc"], + hdrs = ["gpu_fusible.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla/service:hlo", + ], +) + +tf_cc_test( + name = "gpu_fusible_test", + srcs = ["gpu_fusible_test.cc"], + deps = [ + ":gpu_fusible", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 537295292b6ced72c4b2c456557b3c06e0aa5254..528209abc75777440163c2e1512658b8ad36315b 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -17,8 +17,8 @@ limitations under the License. 
#include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -40,7 +40,7 @@ StatusOr> BufferAllocations::Builder::Build( const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator) { const int64 num_buffers = buffer_assignment->Allocations().size(); - auto buffer_allocations = WrapUnique(new BufferAllocations( + auto buffer_allocations = absl::WrapUnique(new BufferAllocations( num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { @@ -62,7 +62,7 @@ StatusOr> BufferAllocations::Builder::Build( if (reinterpret_cast(address.opaque()) % expected_alignment != 0) { return InternalError( - "Address of registered buffer %lld must be a multiple of %llx, but " + "Address of registered buffer %d must be a multiple of %x, but " "was %p", i, kEntryParameterAlignBytes, address.opaque()); } @@ -83,7 +83,7 @@ StatusOr> BufferAllocations::Builder::Build( 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " - "multiple of %llx, but was %p", + "multiple of 0x%x, but was %p", kXlaAllocatedBufferAlignBytes, buffer.opaque()); } // We do manual memory management within BufferAllocations. Be sure not diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index f13eab0dd787a2bfa687c991f9d808568360fd24..14186b8faa68ad8492ea4863fcd7bd746e2eae48 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -20,10 +20,10 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc new file mode 100644 index 0000000000000000000000000000000000000000..13c83c9199fb1bbd8b00dbd601afcb677f92bbee --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -0,0 +1,204 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" + +#include +#include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { +namespace gpu { + +static constexpr float kTolerance = 0.1f; + +static string GetCompHloText(size_t num_elements) { + // Implements the textual format of the comparison routine, as it's more + // readable. 
+ static constexpr char kF16CompHloText[] = R"( +HloModule CompareF16 + +MaxF32 { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %max = f32[] maximum(%lhs, %rhs) +} + +Canonicalize (aparam: f16[SIZE]) -> f32[SIZE] { + %min_constant = f32[] constant(-65505) + %max_constant = f32[] constant(65505) + %large_constant = f32[] constant(1048576) + %min_values = f32[SIZE] broadcast(%min_constant), dimensions={} + %max_values = f32[SIZE] broadcast(%max_constant), dimensions={} + %large_values = f32[SIZE] broadcast(%large_constant), dimensions={} + + %a = f16[SIZE] parameter(0) + %converted = f32[SIZE] convert(%a) + %clamped = f32[SIZE] clamp(%min_values, %converted, %max_values) + + // Since the clamp() above already took care of infs, only NaNs will cause + // is-finite() to return false. + %is_finite = pred[SIZE] is-finite(%clamped) + ROOT %result = f32[SIZE] select(%is_finite, %clamped, %large_values) +} + +ENTRY MaxDifference { + %one_constant = f32[] constant(1.0) + %zero_constant = f32[] constant(0.0) + + %ones = f32[SIZE] broadcast(%one_constant), dimensions={} + + %lhs = f16[SIZE] parameter(0) + %rhs = f16[SIZE] parameter(1) + %lhs_canonical = f32[SIZE] call(%lhs), to_apply=Canonicalize + %rhs_canonical = f32[SIZE] call(%rhs), to_apply=Canonicalize + %sub = f32[SIZE] subtract(%lhs_canonical, %rhs_canonical) + %sub_abs = f32[SIZE] abs(%sub) + %lhs_abs = f32[SIZE] abs(%lhs_canonical) + %rhs_abs = f32[SIZE] abs(%rhs_canonical) + %max = f32[SIZE] maximum(%lhs_abs, %rhs_abs) + %denominator = f32[SIZE] add(%max, %ones) + %error = f32[SIZE] divide(%sub_abs, %denominator) + ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32 +})"; + return absl::StrReplaceAll(kF16CompHloText, + {{"SIZE", absl::StrCat(num_elements)}}); +} + +StatusOr F16BufferComparator::Create( + se::DeviceMemory ref_buffer, Compiler* compiler, + DeviceMemoryAllocator* allocator, se::Stream* stream) { + auto stream_exec = stream->parent(); + int64 
num_elements = ref_buffer.ElementCount(); + + // One may consider using hlo_runner to do all the compilation and execution. + // However, as of the time hlo_runner doesn't support injection for Compiler*, + // Stream*, or even the allocator. We may revisit this in the future if it + // proves to be a maintenance burden. + TF_ASSIGN_OR_RETURN( + auto exec, ([&]() -> StatusOr> { + HloModuleConfig config; + DebugOptions debug_options; + debug_options.set_xla_backend_optimization_level(2); + config.set_debug_options(debug_options); + TF_ASSIGN_OR_RETURN( + auto module, ParseHloString(GetCompHloText(num_elements), config)); + TF_ASSIGN_OR_RETURN( + module, + compiler->RunHloPasses(std::move(module), stream_exec, nullptr)); + return compiler->RunBackend(std::move(module), stream_exec, nullptr); + }())); + + TF_ASSIGN_OR_RETURN( + auto shaped_buffer, ([&]() -> StatusOr { + auto device_ordinal = stream_exec->device_ordinal(); + TF_ASSIGN_OR_RETURN( + auto owning_buffer, + allocator->Allocate(device_ordinal, ref_buffer.size())); + se::DeviceMemory buffer( + owning_buffer.AsDeviceMemoryBase()); + stream->ThenMemcpy(&buffer, ref_buffer, ref_buffer.size()); + Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements}); + ScopedShapedBuffer ret(shape, shape, allocator, device_ordinal); + ret.set_buffer(std::move(owning_buffer), {}); + return std::move(ret); + }())); + + return F16BufferComparator(stream, allocator, std::move(exec), + std::move(shaped_buffer)); +} + +StatusOr F16BufferComparator::CompareEqualImpl( + se::DeviceMemory test_buffer) { + if (ref_buffer_.root_buffer().size() != test_buffer.size()) { + return InternalError("Mismatched buffer size: %d vs %d", + ref_buffer_.root_buffer().size(), test_buffer.size()); + } + + int64 num_elements = test_buffer.ElementCount(); + + TF_ASSIGN_OR_RETURN( + auto result_buffer, ([&]() -> StatusOr { + auto stream_exec = stream_->parent(); + Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements}); + auto device_ordinal = 
stream_exec->device_ordinal(); + ShapedBuffer shaped_test_buffer(shape, shape, stream_exec->platform(), + device_ordinal); + shaped_test_buffer.set_buffer(test_buffer, {}); + ExecutableRunOptions run_options; + run_options.set_device_ordinal(stream_exec->device_ordinal()); + run_options.set_stream(stream_); + run_options.set_allocator(allocator_); + ServiceExecutableRunOptions service_run_options(run_options); + return exec_->ExecuteOnStream( + &service_run_options, {&ref_buffer_, &shaped_test_buffer}, nullptr); + }())); + + float result; + CHECK(result_buffer.root_buffer().size() == sizeof(result)); + stream_->ThenMemcpy(&result, result_buffer.root_buffer(), sizeof(result)); + TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + return result < kTolerance; +} + +StatusOr F16BufferComparator::CompareEqual( + se::DeviceMemory test_buffer) { + TF_ASSIGN_OR_RETURN(auto result, CompareEqualImpl(test_buffer)); + if (result) { + return true; + } + // Host side code that does the same thing, but report some of the + // differences as well. 
+ int64 n = test_buffer.ElementCount(); + std::vector host_ref_buffer(n), host_test_buffer(n); + stream_->ThenMemcpy(host_ref_buffer.data(), ref_buffer_.root_buffer(), + ref_buffer_.root_buffer().size()); + stream_->ThenMemcpy(host_test_buffer.data(), test_buffer, test_buffer.size()); + TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + + const auto canonicalize = [](float a) -> float { + constexpr float kBigNumer = 1048576.; + constexpr float kMaxFp16Value = 65504.; + if (std::isnan(a)) { + return kBigNumer; + } + if (std::isinf(a)) { + if (a < 0) { + return -(kMaxFp16Value + 1); + } + return kMaxFp16Value + 1; + } + return a; + }; + int differences_seen = 0; + for (int64 i = 0; i < n && differences_seen < 10; i++) { + float original_ref = static_cast(host_ref_buffer[i]); + float original_test = static_cast(host_test_buffer[i]); + float ref = canonicalize(original_ref); + float test = canonicalize(original_test); + if (!(std::abs(ref - test) / (std::max(std::abs(ref), std::abs(test)) + 1) < + kTolerance)) { + differences_seen++; + LOG(ERROR) << "Difference at " << i << ": " << original_ref << " vs " + << original_test; + } + } + + return false; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.h b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h new file mode 100644 index 0000000000000000000000000000000000000000..bf2ba78ceacaea1070830f758c3712b1378bd96f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ + +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A fp16 comparator that internally keeps a reference buffer, and compares it +// against other test buffers. +class F16BufferComparator { + public: + F16BufferComparator(const F16BufferComparator&) = delete; + F16BufferComparator(F16BufferComparator&&) = default; + + // Creates a new comparator. It internally allocates a buffer initialized by + // ref_buffer. + static StatusOr Create( + se::DeviceMemory ref_buffer, Compiler* compiler, + DeviceMemoryAllocator* allocator, se::Stream* stream); + + // Returns true if the internally allocated buffer "compares equal" to + // test_buffer. The definition of "equal" is: + // * All NaNs equal. + // * All infs are treated as 65505 or -65505, so that this checker is tolerant + // to fp16 overflows. + // * With NaNs and infs taken care of, a and b compare equal iff: + // abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance + // + // See the implementation for the tolerance value. 
+ StatusOr CompareEqual(se::DeviceMemory test_buffer); + + private: + F16BufferComparator(se::Stream* stream, DeviceMemoryAllocator* allocator, + std::unique_ptr exec, + ScopedShapedBuffer ref_buffer) + : stream_(stream), + allocator_(allocator), + exec_(std::move(exec)), + ref_buffer_(std::move(ref_buffer)) {} + + StatusOr CompareEqualImpl(se::DeviceMemory test_buffer); + + se::Stream* stream_; + DeviceMemoryAllocator* allocator_; + std::unique_ptr exec_; + ScopedShapedBuffer ref_buffer_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..33761d1bd8807df225e2cf505303b120e418576f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" + +#include +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class BufferComparatorTest : public testing::Test { + protected: + BufferComparatorTest() + : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()), + stream_exec_(backend_->default_stream_executor()), + allocator_(stream_exec_->platform(), {stream_exec_}), + compiler_(Compiler::GetForPlatform(stream_exec_->platform()) + .ConsumeValueOrDie()) {} + + // Take floats only for convenience. Still uses half internally. + bool CompareEqualFloatBuffers(const std::vector& lhs_float, + const std::vector& rhs_float) { + std::vector lhs(lhs_float.begin(), lhs_float.end()); + std::vector rhs(rhs_float.begin(), rhs_float.end()); + se::Stream stream(stream_exec_); + stream.Init(); + + auto owning_lhs_buffer = + allocator_ + .Allocate(stream_exec_->device_ordinal(), lhs.size() * sizeof(half)) + .ConsumeValueOrDie(); + + auto owning_rhs_buffer = + allocator_ + .Allocate(stream_exec_->device_ordinal(), rhs.size() * sizeof(half)) + .ConsumeValueOrDie(); + + auto lhs_buffer = + se::DeviceMemory(owning_lhs_buffer.AsDeviceMemoryBase()); + auto rhs_buffer = + se::DeviceMemory(owning_rhs_buffer.AsDeviceMemoryBase()); + + stream.ThenMemcpy(&lhs_buffer, lhs.data(), lhs_buffer.size()); + stream.ThenMemcpy(&rhs_buffer, rhs.data(), rhs_buffer.size()); + + TF_CHECK_OK(stream.BlockHostUntilDone()); + + return F16BufferComparator::Create(lhs_buffer, compiler_, &allocator_, + &stream) + .ConsumeValueOrDie() + .CompareEqual(rhs_buffer) + .ConsumeValueOrDie(); + } + + std::unique_ptr backend_; + se::StreamExecutor* stream_exec_; + StreamExecutorMemoryAllocator allocator_; + Compiler* compiler_; +}; + +TEST_F(BufferComparatorTest, TestNaNs) 
{ + EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")})); + // NaN values with different bit patterns should compare equal. + EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("1234")})); + EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.})); +} + +TEST_F(BufferComparatorTest, TestInfs) { + const auto inf = std::numeric_limits::infinity(); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504})); + EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504})); + + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); +} + +TEST_F(BufferComparatorTest, TestNumbers) { + EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1})); + EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9})); +} + +TEST_F(BufferComparatorTest, TestMultiple) { + EXPECT_TRUE(CompareEqualFloatBuffers({20, 30, 40, 50, 60}, + {20.1, 30.1, 40.1, 50.1, 60.1})); + std::vector lhs(200); + std::vector rhs(200); + for (int i = 0; i < 200; i++) { + EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the same at index " << i; + lhs[i] = 3; + rhs[i] = 5; + EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the different at index " << i; + lhs[i] = 0; + rhs[i] = 0; + } +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 
5780e0af40699bb6ac2c190c09cd02023fb44db7..9ed523998bf07567133fdac0e40b12b8ce4ea3b0 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -59,7 +59,7 @@ Status ConditionalThunk::ExecuteOnStream( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to retrieve predicate value on stream %p: %s.", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } // Execute the true or the false computation depending on the value of the diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 7833a4077e6c6ee4960665f37fb01a35530fd302..05448d863dd2cfe69ad70168be40cdea5bc7017f 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -17,12 +17,11 @@ limitations under the License. 
#include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -38,8 +37,8 @@ ConvolutionThunk::ConvolutionThunk( const BufferAllocation::Slice& tuple_result_buffer, const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, - bool tensor_ops_enabled, const HloInstruction* hlo) + const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count, + int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo) : Thunk(Kind::kConvolution, hlo), convolution_kind_(convolution_kind), input_buffer_(input_buffer), @@ -52,6 +51,7 @@ ConvolutionThunk::ConvolutionThunk( output_shape_(output_shape), window_(window), dim_nums_(dim_nums), + feature_group_count_(feature_group_count), algorithm_(algorithm), tensor_ops_enabled_(tensor_ops_enabled) {} @@ -73,8 +73,8 @@ Status ConvolutionThunk::ExecuteOnStream( auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); TF_RETURN_IF_ERROR(RunCudnnConvolution( convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, - filter_data, output_data, scratch, window_, dim_nums_, algorithm_config, - stream)); + filter_data, output_data, scratch, window_, dim_nums_, + feature_group_count_, algorithm_config, stream)); // Figure out which of output/input/filter is the result produced by // this op, and write the result tuple. 
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index d76ca6698dcf462c3c4961ce6a9784822af3a81f..68d67c40c56145a137398540e90b75b33642589f 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -59,7 +59,8 @@ class ConvolutionThunk : public Thunk { const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, + const ConvolutionDimensionNumbers& dim_nums, + int64 feature_group_count, int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo); ConvolutionThunk(const ConvolutionThunk&) = delete; @@ -71,19 +72,6 @@ class ConvolutionThunk : public Thunk { HloExecutionProfiler* profiler) override; private: - class ScratchAllocator; - - Status Convolve(const se::dnn::BatchDescriptor& input_descriptor, - se::DeviceMemory input_data, - const se::dnn::FilterDescriptor& filter_descriptor, - se::DeviceMemory filter_data, - const se::dnn::BatchDescriptor& output_descriptor, - se::DeviceMemory output_data, - const se::dnn::ConvolutionDescriptor& 
convolution_descriptor, - const se::dnn::AlgorithmConfig& algorithm_config, - se::Stream* stream, ScratchAllocator* scratch_allocator, - se::dnn::ProfileResult* profile_result); - const CudnnConvKind convolution_kind_; const BufferAllocation::Slice input_buffer_; @@ -98,6 +86,7 @@ class ConvolutionThunk : public Thunk { const Window window_; const ConvolutionDimensionNumbers dim_nums_; + int64 feature_group_count_; int64 algorithm_; bool tensor_ops_enabled_; }; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h index e09cde9abf85454c7a020566cd8c2671ae12ffc3..6e2e330edd4beabe0b395f05b80d57612d63f110 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h @@ -54,9 +54,7 @@ namespace gpu { // BatchNormRewriter. class CudnnBatchNormRewriter : public HloPassInterface { public: - tensorflow::StringPiece name() const override { - return "cudnn_batchnorm_rewriter"; - } + absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 7b172812c36bb141787ef3a9285d6f7ce13e343b..bc3c6f72f6799f84169748465d62c3f2a306d5fc 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -17,12 +17,11 @@ limitations under the License. 
#include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 7348307ec8a7286dfb733d6b9685862b20f11ac9..5c2555148ae5de4a15e5a5f003b4783c64a20e9c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,24 +14,25 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" namespace xla { namespace gpu { namespace { +using absl::optional; using se::DeviceMemoryBase; using se::dnn::AlgorithmConfig; using se::dnn::AlgorithmDesc; -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; class ScratchAllocator : public 
se::ScratchAllocator { public: @@ -59,8 +60,8 @@ StatusOr> ScratchAllocator::AllocateBytes( if (byte_size > GetMemoryLimitInBytes(stream)) { return se::port::Status( se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size, GetMemoryLimitInBytes(stream))); } @@ -128,14 +129,14 @@ std::vector GetAlgorithms(CudnnConvKind kind, string AlgorithmToString(const AlgorithmDesc& algo) { if (algo.tensor_ops_enabled()) { - return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); + return absl::StrCat(algo.algo_id(), "+TC"); } - return tensorflow::strings::StrCat(algo.algo_id()); + return absl::StrCat(algo.algo_id()); } string NumBytesToString(int64 bytes) { - return tensorflow::strings::StrCat( - tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); + return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (", + bytes, "B)"); } // Acquires a process-global lock on the device pointed to by the given @@ -173,11 +174,18 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. -optional> +StatusOr> CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, + HloInstruction* instr) { + CHECK_EQ(input_shape.element_type(), filter_shape.element_type()); + CHECK_EQ(input_shape.element_type(), output_shape.element_type()); + // TODO(timshen): for now only check fp16. 
It can be expanded to other types, + // with some work on the HLO routines. + const bool cross_check_enabled = input_shape.element_type() == xla::F16; + // Don't run this function concurrently on the same GPU. // // This is a bit of a hack and doesn't protect us against arbitrary concurrent @@ -185,6 +193,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // concurrently and then run them sequentially. tensorflow::mutex_lock lock = LockGpu(stream_exec_); + // Make sure any previous activity on this executor is done. We don't want to + // interfere with programs that are still running on the GPU. + if (!stream_exec_->SynchronizeAllActivity()) { + return InternalError("Failed to synchronize GPU for autotuning."); + } + // Create a stream for us to do our work on. se::Stream stream{stream_exec_}; stream.Init(); @@ -197,60 +211,82 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( if (allocator_ != nullptr) { allocator = allocator_; } else { - se_allocator.emplace( - stream_exec_->platform(), - tensorflow::gtl::ArraySlice({stream_exec_})); + se_allocator.emplace(stream_exec_->platform(), + absl::Span({stream_exec_})); allocator = &*se_allocator; } // Allocate space for the input, filter, and output of the convolution. We // use a ScratchAllocator for this instead of calling allocator_ directly so // that our allocations don't leak. - // - // We don't put any data in these buffers, because (in theory, anyway) the - // speed of a conv isn't affected by the data being convolved. 
ScratchAllocator input_output_allocator(device_ordinal, allocator); - StatusOr maybe_input_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(input_shape)); - StatusOr maybe_filter_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(filter_shape)); - StatusOr maybe_output_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(output_shape)); - if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() || - !maybe_output_buf.ok()) { - LOG(WARNING) - << "Couldn't allocate space for input/filter/output of convolution " - << instr->ToString() << ". Falling back to default algorithm."; - return nullopt; + TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(input_shape))); + TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(filter_shape))); + TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(output_shape))); + + if (cross_check_enabled) { + // Broadcast a constant to the buffer, instead of zeroing the buffer. A + // non-zero constant is useful for the cross checking, because zero-inputs + // may not always reveal the bugs. 
+ const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) { + CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4); + size_t left_over_bytes = buffer.size() % 4; + CHECK_EQ(0, left_over_bytes % 2); + + constexpr float kBroadcastedConstant = 0.1f; + static const Eigen::half halfs[2] = {Eigen::half(kBroadcastedConstant), + Eigen::half(kBroadcastedConstant)}; + uint32 bits; + static_assert(sizeof(bits) == sizeof(halfs), ""); + memcpy(&bits, halfs, sizeof(bits)); + + size_t aligned_size = buffer.size() / 4 * 4; + stream.ThenMemset32(&buffer, bits, aligned_size); + + DeviceMemoryBase left_over( + static_cast(buffer.opaque()) + aligned_size, left_over_bytes); + stream.ThenMemcpy(&left_over, halfs, left_over_bytes); + }; + initialize_f16(input_buf); + initialize_f16(filter_buf); + initialize_f16(output_buf); + } else { + // Although we don't have evidence this matters, zero out the buffers before + // autotuning. It's conceivable that using uninitialized memory as the + // inputs might affect performance if e.g. the inputs contain denormals, and + // this is easy enough. + stream.ThenMemZero(&input_buf, input_buf.size()) + .ThenMemZero(&filter_buf, filter_buf.size()) + .ThenMemZero(&output_buf, output_buf.size()); } - DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie(); - DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie(); - DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie(); - - // Although we don't have evidence this matters, zero out the buffers before - // autotuning. It's conceivable that using uninitialized memory as the inputs - // might affect performance if e.g. the inputs contain denormals, and this is - // easy enough. - if (!stream.ThenMemZero(&input_buf, input_buf.size()) - .ThenMemZero(&filter_buf, filter_buf.size()) - .ThenMemZero(&output_buf, output_buf.size()) - .BlockHostUntilDone() - .ok()) { - LOG(WARNING) - << "Couldn't zero out input/filter/output buffer for convolution " - << instr->ToString() << ". 
Falling back to default algorithm."; - return nullopt; - } + DeviceMemoryBase* result_buf = [&] { + switch (kind) { + case CudnnConvKind::kBackwardFilter: + return &filter_buf; + case CudnnConvKind::kBackwardInput: + return &input_buf; + case CudnnConvKind::kForward: + return &output_buf; + } + }(); const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( input_shape, output_shape, dnums, stream_exec_); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; + optional comparator; + // Use the first algorithm that's supported as reference. There isn't a + // particular reason to use it, as any algorithm suffices. It doesn't make + // this algorithm considered correct, though. + optional first_algorithm; for (const AlgorithmDesc& alg : GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); @@ -259,13 +295,49 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( << instr->ToString(); bool launch_ok = - RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - input_buf, filter_buf, output_buf, - &scratch_allocator, window, dnums, - AlgorithmConfig(alg), &stream, &profile_result) + RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, input_buf, + filter_buf, output_buf, &scratch_allocator, window, dnums, + feature_group_count, AlgorithmConfig(alg), &stream, &profile_result) .ok(); if (launch_ok && profile_result.is_valid()) { + const bool crash_on_checking_failure = + instr->GetModule() + ->config() + .debug_options() + .xla_gpu_crash_on_verification_failures(); + if (comparator.has_value()) { + StatusOr result = comparator->CompareEqual( + se::DeviceMemory(*result_buf)); + if (!result.ok()) { + LOG(ERROR) << "Unable to compare " + << AlgorithmToString(*first_algorithm) << " against " + << AlgorithmToString(alg) << " for " << instr->ToString() + << ": " << result.status(); + CHECK(!crash_on_checking_failure); + } else if (!result.ValueOrDie()) { 
+ LOG(ERROR) << "Results mismatch between different convolution " + "algorithms. This is likely a bug in convolution, or " + "an excessive loss of precision in convolution. " + << instr->ToString() << " for " + << AlgorithmToString(*first_algorithm) << " vs " + << AlgorithmToString(alg); + CHECK(!crash_on_checking_failure); + } + } else if (cross_check_enabled) { + auto comp = F16BufferComparator::Create( + se::DeviceMemory(*result_buf), compiler_, allocator, + &stream); + if (comp.ok()) { + comparator.emplace(comp.ConsumeValueOrDie()); + first_algorithm.emplace(alg); + } else { + LOG(ERROR) << "Fail to initialize buffer comparator: " + << comp.status() << ", instruction: " << instr->ToString(); + CHECK(!crash_on_checking_failure); + } + } int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " succeeded, taking " << profile_result.elapsed_time_in_ms() @@ -292,9 +364,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( best_result_bytes_used); } - LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString() - << " failed. Falling back to default algorithm."; - return nullopt; + return InternalError( + "All algorithms tried for convolution %s failed. 
Falling back to " + "default algorithm.", + instr->ToString()); } StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( @@ -305,28 +378,33 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( const auto& lhs_shape = instr->operand(0)->shape(); const auto& rhs_shape = instr->operand(1)->shape(); const auto& conv_result_shape = instr->shape().tuple_shapes(0); - optional> alg_scratch_and_tc; + StatusOr> alg_scratch_and_tc; if (call_target == kCudnnConvForwardCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kForward, /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, - instr->window(), instr->convolution_dimension_numbers(), instr); + alg_scratch_and_tc = + PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape, + /*filter_shape=*/rhs_shape, + /*output_shape=*/conv_result_shape, instr->window(), + instr->convolution_dimension_numbers(), + instr->feature_group_count(), instr); } else if (call_target == kCudnnConvBackwardInputCallTarget) { alg_scratch_and_tc = PickBestAlgorithm( CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), - instr->convolution_dimension_numbers(), instr); + instr->convolution_dimension_numbers(), instr->feature_group_count(), + instr); } else if (call_target == kCudnnConvBackwardFilterCallTarget) { alg_scratch_and_tc = PickBestAlgorithm( CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, - instr->window(), instr->convolution_dimension_numbers(), instr); + instr->window(), instr->convolution_dimension_numbers(), + instr->feature_group_count(), instr); } else { LOG(FATAL) << "Unknown custom call target for cudnn conv: " << instr->ToString(); } - if (!alg_scratch_and_tc.has_value()) { + if (!alg_scratch_and_tc.ok()) { + LOG(ERROR) << alg_scratch_and_tc.status(); return false; } @@ -334,7 +412,8 @@ 
StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( bool tensor_ops_enabled; int64 scratch_bytes; - std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc; + std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = + alg_scratch_and_tc.ConsumeValueOrDie(); VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " << NumBytesToString(scratch_bytes) @@ -352,14 +431,9 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( backend_config.set_algorithm(algorithm); backend_config.set_tensor_ops_enabled(tensor_ops_enabled); - HloInstruction* new_call = - computation->AddInstruction(HloInstruction::CreateCustomCall( - new_call_shape, - {instr->mutable_operand(0), instr->mutable_operand(1)}, - instr->custom_call_target())); - new_call->set_window(instr->window()); - new_call->set_convolution_dimension_numbers( - instr->convolution_dimension_numbers()); + HloInstruction* new_call = computation->AddInstruction( + instr->CloneWithNewOperands(new_call_shape, {instr->mutable_operand(0), + instr->mutable_operand(1)})); TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index bc5d1ce94afd2075a006899f0f6bcf64352e5e99..0cb01161b023b900c8c4b1386b679fe2bd5db802 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -16,11 +16,12 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -34,10 +35,11 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { // memory while timing the various convolution algorithms. If it's null, // we'll use the default allocator on the StreamExecutor. CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator) - : stream_exec_(stream_exec), allocator_(allocator) {} + DeviceMemoryAllocator* allocator, + Compiler* compiler) + : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cudnn-convolution-algorithm-picker"; } @@ -46,13 +48,15 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { private: StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - tensorflow::gtl::optional> PickBestAlgorithm( + StatusOr> PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, + HloInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null + Compiler* compiler_; }; } // namespace 
gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 905b5ee8767d0fa0514c7f1abf83bc089cd08045..9bf721ecd2ad938e71f88a6fc65cd2d3bd25161e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -59,6 +59,11 @@ std::tuple MatchBackwardFilter( HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + // TODO(b/31709653): Figure out if we can use grouped convolutions also on + // backward filter. + if (conv->feature_group_count() > 1) { + return no_match_result; + } // Step 1: match the instruction pattern without considering the paddings and // dimension numbers just yet. We may need some generic pattern matcher // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h @@ -218,6 +223,12 @@ std::tuple MatchBackwardInput( const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + // TODO(b/31709653): Figure out if we can use grouped convolutions also on + // backward input. + if (conv->feature_group_count() > 1) { + return no_match_result; + } + // Match instruction pattern. CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); HloInstruction* reverse_filter = conv->mutable_operand(1); @@ -234,6 +245,23 @@ std::tuple MatchBackwardInput( << "Backward input convolution should reverse all kernel dimensions."; return no_match_result; } + } else if (reverse_filter->IsConstant()) { + // If the filter is a constant, we're willing to pattern-match to a + // backwards-input conv, on the theory that + // + // a) reversing a constant is free, and + // b) even if the user specified this filter as reverse(constant), we would + // long ago have constant-folded away the reverse. 
+ // + // If the constant has any other uses, reversing it isn't entirely free, + // since we'd now have two constants to keep in memory. But hopefully it's + // free enough. + // + // TODO(jlebar): Should we do this even if the filter is not a constant? + // Reversing a non-constant filter is probably cheaper than padding the + // input! + + // Nothing to do, just fall through. } else { // Possibly 1x1 filter. for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) { @@ -373,22 +401,25 @@ std::tuple MatchBackwardInput( } } - // Fuse the matched HLOs into a backward convolution instruction. - // - // If the reverse is omitted (for 1x1 filters) in the original pattern, we add - // it back in the fusion instruction so that later passes (such as - // PadInsertion) can handle such fusion instructions easily. + // OK, it's a match! Canonicalize the conv's filter so that it's a reverse. + // This simplifies things for our caller, and algebraic-simplifier will later + // remove any unnecessary reverses. if (reverse_filter->opcode() != HloOpcode::kReverse) { - reverse_filter = reverse_filter->parent()->AddInstruction( + // Create a double-reverse, which is a nop. 
+ HloComputation* c = conv->parent(); + reverse_filter = c->AddInstruction( + HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, + AsInt64Slice(kernel_spatial_dims))); + reverse_filter = c->AddInstruction( HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, AsInt64Slice(kernel_spatial_dims))); TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); } + dnums.set_kernel_input_feature_dimension( conv->convolution_dimension_numbers().kernel_output_feature_dimension()); dnums.set_kernel_output_feature_dimension( conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - return std::make_tuple(true, new_window, dnums); } @@ -405,7 +436,7 @@ StatusOr RunOnInstruction(HloInstruction* conv) { if (match) { return CreateCudnnConvBackwardFilter( conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), - window, dnums); + window, dnums, conv->feature_group_count()); } std::tie(match, window, dnums) = MatchBackwardInput(conv); @@ -415,15 +446,17 @@ StatusOr RunOnInstruction(HloInstruction* conv) { CHECK_EQ(reverse->opcode(), HloOpcode::kReverse); HloInstruction* rhs = reverse->mutable_operand(0); - return CreateCudnnConvBackwardInput( - conv->shape(), conv->mutable_operand(0), rhs, window, dnums); + return CreateCudnnConvBackwardInput(conv->shape(), + conv->mutable_operand(0), rhs, window, + dnums, conv->feature_group_count()); } // If all else fails, try a forward convolution. 
if (CanImplementAsCudnnForwardConv(conv)) { return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), conv->window(), - conv->convolution_dimension_numbers()); + conv->convolution_dimension_numbers(), + conv->feature_group_count()); } return nullptr; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h index 0c0578d88840fed1d77f7456c9acef27dec380f5..fbe7e9849458e9d52be15b3f5610479ab68ffa4c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h @@ -26,7 +26,7 @@ namespace gpu { // backwards-input convolutions into CustomCall HLOs that call into cuDNN. class CudnnConvolutionRewriter : public HloPassInterface { public: - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cudnn-convolution-rewriter"; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index 65588b6aaf24da628ea586eb52c462b78b8daaa7..46c23db4652cccb06c9ca2a199a46ae04b332286 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -32,10 +32,13 @@ namespace gpu { namespace { namespace op = xla::testing::opcode_matchers; +using ::testing::_; -class CudnnConvolutionRewriterTest : public HloTestBase { +class CudnnConvolutionRewriterTest : public HloVerifiedTestBase { public: - CudnnConvolutionRewriterTest() { + CudnnConvolutionRewriterTest() + : HloVerifiedTestBase(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false) { for (int i = 0; i < 2; ++i) { WindowDimension* window_dim = default_conv_window_.add_dimensions(); window_dim->set_size(1); @@ -114,7 +117,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -142,7 +145,7 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -172,7 +175,7 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 
0)); @@ -202,7 +205,7 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -230,7 +233,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -280,7 +283,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( @@ -325,7 +328,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -357,7 +360,7 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -410,7 +413,7 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); 
HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -457,7 +460,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -510,7 +513,7 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); const HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -562,12 +565,38 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } +// Check that we will materialize a reversed version of a constant in order to +// pattern-match a backwards input convolution. 
+TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) { + Array4D constant_arr(4, 4, 2, 2); + constant_arr.FillIota(0); + string constant_str = + LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString(); + ParseAndVerifyModule(absl::StrFormat(R"( + HloModule test + + ENTRY entry_computation { + param0 = f32[128,2,16,16]{3,2,1,0} parameter(0) + constant = f32[4,4,2,2]{3,2,1,0} constant(%s) + ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant), + window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2}, + dim_labels=bf01_01oi->bf01, feature_group_count=1 + })", + constant_str)); + EXPECT_TRUE(RunPass(&module())); + EXPECT_THAT( + module().entry_computation()->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _, + op::Reverse(op::Constant())), + 0)); +} + } // anonymous namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 0645fbb3ad39f1f1649caf45a6068b5a196c30b9..05125e9d1fb3cd03cb72b7854fc28c767b49fd64 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -56,7 +57,7 @@ class ScratchBufAllocator : public se::ScratchAllocator { "Can't allocate twice from a ScratchBufAllocator."); } if (byte_size > scratch_.size()) { - return se::port::InternalError(tensorflow::strings::StrCat( + return se::port::InternalError(absl::StrCat( "Can't allocate ", byte_size, " bytes from a ScratchBufAllocator of size ", scratch_.size())); } @@ -76,8 +77,9 @@ Status RunCudnnConvolution( const Shape& output_shape, DeviceMemory input_buf, DeviceMemory filter_buf, DeviceMemory output_buf, se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm, - Stream* stream, ProfileResult* profile_result /*= nullptr*/) { + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, + AlgorithmConfig algorithm, Stream* stream, + ProfileResult* profile_result /*= nullptr*/) { VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); VLOG(3) << "tensor_ops_enabled: " << algorithm.algorithm().tensor_ops_enabled(); @@ -96,15 +98,9 @@ Status RunCudnnConvolution( // tensorflow/python/ops/nn_ops.py). 
const int effective_num_dimensions = std::max(2, num_dimensions); - if (std::is_same::value) { - CHECK_EQ(F32, output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - } else if (std::is_same::value) { - CHECK_EQ(F16, output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); - } else { - LOG(FATAL) << ShapeUtil::HumanString(output_shape); - } + CHECK_EQ(primitive_util::NativeToPrimitiveType(), + output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); @@ -149,6 +145,7 @@ Status RunCudnnConvolution( } ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); + convolution_descriptor.set_group_count(feature_group_count); for (int dim = 0; dim < num_dimensions; ++dim) { convolution_descriptor .set_zero_padding( @@ -202,8 +199,8 @@ Status RunCudnnConvolution( if (!stream->ok()) { return InternalError( - "Unable to launch convolution with type %s and algorithm (%lld, %lld)", - CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(), + "Unable to launch convolution with type %s and algorithm (%d, %d)", + CudnnConvKindToString(kind), algorithm.algorithm().algo_id(), algorithm.algorithm_no_scratch().algo_id()); } return Status::OK(); @@ -227,14 +224,14 @@ Status RunCudnnConvolution( const Shape& output_shape, se::DeviceMemoryBase input_buf, se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - input_buf, filter_buf, output_buf, - &scratch_allocator, window, dnums, 
algorithm, - stream, profile_result); + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, input_buf, filter_buf, + output_buf, &scratch_allocator, window, dnums, feature_group_count, + algorithm, stream, profile_result); } Status RunCudnnConvolution( @@ -242,25 +239,35 @@ Status RunCudnnConvolution( const Shape& output_shape, se::DeviceMemoryBase input_buf, se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result) { PrimitiveType output_primitive_type = output_shape.element_type(); - CHECK(output_primitive_type == F32 || output_primitive_type == F16) - << ShapeUtil::HumanString(output_shape); - if (output_primitive_type == F32) { - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), scratch_allocator, window, dnums, - algorithm, stream, profile_result); + switch (output_primitive_type) { + case F16: + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf), + se::DeviceMemory(filter_buf), + se::DeviceMemory(output_buf), scratch_allocator, window, + dnums, feature_group_count, algorithm, stream, profile_result); + case F32: + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf), + se::DeviceMemory(filter_buf), + se::DeviceMemory(output_buf), scratch_allocator, window, dnums, + feature_group_count, algorithm, stream, profile_result); + case F64: + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf), + se::DeviceMemory(filter_buf), + se::DeviceMemory(output_buf), scratch_allocator, window, + 
dnums, feature_group_count, algorithm, stream, profile_result); + default: + LOG(FATAL) << ShapeUtil::HumanString(output_shape); } - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), - scratch_allocator, window, dnums, algorithm, - stream, profile_result); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index 944e4ac686d45408b08ff1faa321510c1c8920ba..a1b4fc71d0cac3e5ea067ca7941b07cbade8d7cc 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -75,7 +75,7 @@ Status RunCudnnConvolution( const Shape& output_shape, se::DeviceMemoryBase input_buf, se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result = nullptr); @@ -84,7 +84,7 @@ Status RunCudnnConvolution( const Shape& output_shape, se::DeviceMemoryBase input_buf, se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, se::dnn::AlgorithmConfig algorithm, se::Stream* stream, se::dnn::ProfileResult* profile_result = nullptr); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 9b6de115ad7e7f87e431f839c1690858f4bce3fd..c1aaa4bf04ddc31edf723c056805ae5aad994e55 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ 
b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -23,6 +23,8 @@ limitations under the License. #include "tensorflow/core/platform/types.h" // IWYU pragma: no_include "llvm/IR/Attributes.gen.inc" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" @@ -43,16 +45,14 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace gpu { +using absl::StrAppend; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -using tensorflow::strings::StrAppend; namespace { // Returns whether operand is a floating-point literal with the given value. @@ -74,10 +74,8 @@ GpuElementalIrEmitter::GpuElementalIrEmitter( compute_nested_(std::move(compute_nested)) {} StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type) { // The libdevice math functions differentiate between "double" and "float" by // appending an 'f' to the function's name. 
libdevice doesn't have f16 math // functions, so we convert the operands to f32 before calling the function @@ -94,7 +92,7 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( for (int64 i = 0; i < operands.size(); ++i) { if (input_types[i] == F16) { converted_operands[i] = - b_->CreateFPCast(converted_operands[i], b_->getFloatTy()); + FPCast(converted_operands[i], b_->getFloatTy()); converted_input_types[i] = F32; } } @@ -107,22 +105,20 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( break; default: return Unimplemented("Bad type for libdevice math call: %s", - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(output_type)); } llvm::Value* result = EmitMathCall(munged_callee, converted_operands, converted_input_types, output_type) .ValueOrDie(); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type) { // llvm intrinsics differentiate between half/float/double functions via // the suffixes ".f16", ".f32" and ".f64". 
string munged_callee = callee_name; @@ -138,22 +134,20 @@ StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( break; default: return Unimplemented("Bad type for llvm intrinsic math call: %s", - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(output_type)); } return EmitMathCall(munged_callee, operands, input_types, output_type); } StatusOr GpuElementalIrEmitter::EmitMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type) { // Binary math functions transform are of type [T] -> T. for (PrimitiveType input_type : input_types) { if (output_type != input_type) { return Unimplemented("Input type ≠ output type: %s ≠ %s", - PrimitiveType_Name(input_type).c_str(), - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(input_type), + PrimitiveType_Name(output_type)); } } @@ -163,8 +157,7 @@ StatusOr GpuElementalIrEmitter::EmitMathCall( } StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); @@ -183,8 +176,7 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( } StatusOr GpuElementalIrEmitter::EmitPowerOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { CHECK_EQ(op->opcode(), HloOpcode::kPower); PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); @@ -218,7 +210,7 @@ StatusOr 
GpuElementalIrEmitter::EmitPowerOp( // TODO(jlebar): Does this happen with fastmath disabled? If not, should // we force-enable it? TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt()); - return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); + return FDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); } VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString(); @@ -227,55 +219,56 @@ StatusOr GpuElementalIrEmitter::EmitPowerOp( } StatusOr GpuElementalIrEmitter::EmitErfcInv( - PrimitiveType prim_type, llvm::Value* value) const { + PrimitiveType prim_type, llvm::Value* value) { return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitLog( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitLog1p( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitSin( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitCos( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitExp( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); } 
-StatusOr GpuElementalIrEmitter::EmitExpm1( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type); } StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitAtan2( - PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { +StatusOr GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) { return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitTanh( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) { // Emit a fast approximation of tanh instead of calling __nv_tanh. // __nv_tanh is particularly bad because it contains branches, thus // preventing LLVM's load-store vectorizer from working its magic across a @@ -285,17 +278,15 @@ StatusOr GpuElementalIrEmitter::EmitTanh( // Upcast F16 to F32 if necessary. llvm::Type* type = prim_type == F16 ? 
b_->getFloatTy() : value->getType(); - llvm::Value* input = b_->CreateFPCast(value, type); + llvm::Value* input = FPCast(value, type); llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); - return b_->CreateFPCast(fast_tanh, value->getType()); + return FPCast(fast_tanh, value->getType()); } llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type, - tensorflow::gtl::ArraySlice attributes) const { + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type, + absl::Span attributes) { std::vector ir_input_types; for (PrimitiveType input_type : input_types) { ir_input_types.push_back( @@ -315,29 +306,28 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( callee->addFnAttr(attribute); } - return b_->CreateCall(callee, llvm_ir::AsArrayRef(operands)); + return Call(callee, llvm_ir::AsArrayRef(operands)); } -llvm::Value* GpuElementalIrEmitter::EmitThreadId() const { - llvm::Value* block_id = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "block.id"); - llvm::Value* thread_id_in_block = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); - llvm::Value* threads_per_block = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); - return b_->CreateNSWAdd(b_->CreateNSWMul(block_id, threads_per_block), - thread_id_in_block); +llvm::Value* GpuElementalIrEmitter::EmitThreadId() { + llvm::Value* block_id = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "block.id"); 
+ llvm::Value* thread_id_in_block = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); + llvm::Value* threads_per_block = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); + return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); } llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const { + const HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { case HloOpcode::kMap: return [=, &operand_to_generator]( @@ -383,7 +373,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * init_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - b_->CreateStore(init_value, accum_ptr); + Store(init_value, accum_ptr); } llvm::Type* index_type = index.GetType(); @@ -405,22 +395,21 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( IrArray::Index input_index(index_type, index.size()); llvm::Value* in_bounds = b_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* stridden_index = b_->CreateNSWMul( + llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); - input_index[i] = b_->CreateNSWSub( - b_->CreateNSWAdd(stridden_index, window_index[i]), - index_typed_const(window.dimensions(i).padding_low())); + input_index[i] = + NSWSub(NSWAdd(stridden_index, window_index[i]), + index_typed_const(window.dimensions(i).padding_low())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. 
This // comparison is equivalent to the unsigned comparison // input_index[i] < bound, as a negative value wraps to a large // positive value. - in_bounds = b_->CreateAnd( - in_bounds, - b_->CreateICmpULT( - input_index[i], - index_typed_const(operand->shape().dimensions(i)))); + in_bounds = + And(in_bounds, + ICmpULT(input_index[i], + index_typed_const(operand->shape().dimensions(i)))); } llvm_ir::LlvmIfData if_data = @@ -432,12 +421,11 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( operand_to_generator.at(operand)(input_index)); TF_ASSIGN_OR_RETURN( llvm::Value * accum_value, - compute_nested_(*hlo->to_apply(), - {b_->CreateLoad(accum_ptr), input_value})); - b_->CreateStore(accum_value, accum_ptr); + compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value})); + Store(accum_value, accum_ptr); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); - return b_->CreateLoad(accum_ptr); + return Load(accum_ptr); }; case HloOpcode::kReduce: // TODO(b/112040122): This should be supported. diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 84454d31bb820a3de6ef3364bd205b8115bd95c0..e8b56a39ce58b6aab35c1c977553c7ff7e753273 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -30,7 +31,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace gpu { @@ -38,9 +38,9 @@ namespace gpu { class GpuElementalIrEmitter : public ElementalIrEmitter { public: // A NestedComputer computes an element of the output of the given computation - // given an ArraySlice of its input elements. + // given a Span of its input elements. using NestedComputer = std::function( - const HloComputation&, tensorflow::gtl::ArraySlice)>; + const HloComputation&, absl::Span)>; GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config, llvm::Module* module, llvm::IRBuilder<>* b, @@ -48,85 +48,77 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const override; + const HloToElementGeneratorMap& operand_to_generator) override; protected: - StatusOr EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const override; + StatusOr EmitFloatBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value) override; StatusOr EmitErfcInv(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitLog(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitSin(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitCos(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitExp(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitExpm1(PrimitiveType prim_type, - 
llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; - llvm::Value* EmitThreadId() const override; + llvm::Value* EmitThreadId() override; private: // Emits IR for op, which must have opcode kPower. StatusOr EmitPowerOp(const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); // Emits IR to call a device function named "callee_name" on the given // operand. Returns the IR value that represents the return value. llvm::Value* EmitDeviceFunctionCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_type, - PrimitiveType output_type, - tensorflow::gtl::ArraySlice attributes) const; + const string& callee_name, absl::Span operands, + absl::Span input_type, PrimitiveType output_type, + absl::Span attributes); // Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the // return value of the function. StatusOr EmitLlvmIntrinsicMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const; + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type); // Emits IR to call a libdevice function of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the // return value of the function. 
StatusOr EmitLibdeviceMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const; + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type); // Emits IR to call a function of type [T] -> T. Does not munge callee_name. // Returns the IR value that represents the return value of the function. StatusOr EmitMathCall( - const string& callee_name, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const; + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type); const HloModuleConfig& hlo_module_config_; NestedComputer compute_nested_; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 0cdddf8bcfd4e849b311bf810eda471d79dbf106..ca4a605af5d3b6b58b603d7ddad60ed9ae8a212f 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -17,11 +17,11 @@ limitations under the License. 
#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -43,8 +43,8 @@ StatusOr> FftScratchAllocator::AllocateBytes( if (byte_size > GetMemoryLimitInBytes(stream)) { return se::port::Status( se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size, GetMemoryLimitInBytes(stream))); } @@ -92,8 +92,7 @@ string FftTypeToString(se::fft::Type type) { } // namespace -FftThunk::FftThunk(FftType fft_type, - tensorflow::gtl::ArraySlice fft_length, +FftThunk::FftThunk(FftType fft_type, absl::Span fft_length, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& output_buffer, const Shape& input_shape, const Shape& output_shape, @@ -213,7 +212,7 @@ Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, return Status::OK(); } return InternalError("Unable to launch fft for thunk %p with type %s", this, - FftTypeToString(fft_type_).c_str()); + FftTypeToString(fft_type_)); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index 8c53be5077b0c5a88d303c729457139c6cb800f1..2be50e08bd2b561b44245b20e1fb200e31e65a41 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -62,7 +62,7 @@ class FftThunk : public Thunk { public: // Constructs a thunk for launching an FFT on a stream. // Semantics of null hlo_instruction argument are as in Thunk. - FftThunk(FftType fft_type, tensorflow::gtl::ArraySlice fft_length, + FftThunk(FftType fft_type, absl::Span fft_length, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& output_buffer, const Shape& input_shape, const Shape& output_shape, diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 2fd2206324e5f763490780a54880825a772b7ea2..88f0b4d71c915c37f0b58cb91a8788fd8f9cc452 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -28,7 +28,7 @@ ForThunk::ForThunk(const int64 loop_limit, const HloInstruction* hlo) : Thunk(Kind::kWhile, hlo), loop_limit_(loop_limit), - body_thunk_sequence_(MakeUnique( + body_thunk_sequence_(absl::make_unique( // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ // constructor because this SequentialThunk is logically "part of" // this ForThunk, and shouldn't be profiled separately from it. diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 3cd30b754c3242f00c704de1afab2282ed827b41..30c1f9088968305ad0207164ecb07ba13cc89ee6 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -18,12 +18,14 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace gpu { @@ -64,10 +66,11 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) { // Slice for a more accurate estimate of bytes read. 
double bytes = 0.0; for (auto& instruction : instructions) { - if (c_all_of(instruction->users(), [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kSlice || - instruction->opcode() == HloOpcode::kDynamicSlice; - })) { + if (absl::c_all_of( + instruction->users(), [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSlice || + instruction->opcode() == HloOpcode::kDynamicSlice; + })) { // All users are slice: accumulate bytes of all user slice instructions. for (auto& user : instruction->users()) { bytes += ShapeUtil::ByteSizeOf(user->shape()); @@ -223,10 +226,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Skip 'fusion' instruction if we cannot merge into all of its users. // Merging into all users enables the removal of 'fusion' from the // computation. - if (!c_all_of(fusion->users(), [](const HloInstruction* user) { + if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) { return user->opcode() == HloOpcode::kFusion && (user->fusion_kind() == HloInstruction::FusionKind::kLoop || - user->fusion_kind() == HloInstruction::FusionKind::kInput); + (user->fusion_kind() == HloInstruction::FusionKind::kInput && + LayoutsAreReduceInputFusionFriendly(*fusion, *user))); })) { VLOG(3) << "Not merging " << fusion->name() << ": Some of its users are not loop/input fusion kernels."; @@ -241,11 +245,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // If 'fusion' has just one user, then an earlier fusion pass chose not to // fuse this producer/comsumer pair (likely because of expensive instruction // re-use by the consumer), and so we honor that choice here as well. 
- if (c_any_of(fusion->fused_instructions(), - [](const HloInstruction* instruction) { - return instruction->opcode() != HloOpcode::kParameter && - GpuInstructionFusion::IsExpensive(*instruction); - })) { + if (absl::c_any_of(fusion->fused_instructions(), + [](const HloInstruction* instruction) { + return instruction->opcode() != HloOpcode::kParameter && + GpuInstructionFusion::IsExpensive(*instruction); + })) { VLOG(3) << "Not merging " << fusion->name() << ": Contains one or more expensive instructions."; ++num_fail_expensive_fused_instruction_; @@ -287,11 +291,10 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { << " flops_to_bytes_ratio: " << CalculateFlopsToBytesRatio(fusion) << " merged_to_current_bytes_ratio: " << merged_to_current_bytes_ratio << " into users { " - << tensorflow::str_util::Join(users, ", ", - [](string* out, HloInstruction* user) { - tensorflow::strings::StrAppend( - out, user->name()); - }) + << absl::StrJoin(users, ", ", + [](string* out, HloInstruction* user) { + absl::StrAppend(out, user->name()); + }) << " }"; // Remove 'fusion' instruction. 
CHECK_EQ(0, fusion->user_count()); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index 4c523a66de977cd32423b25f0d165c4f4ba51c4a..7e3f5775b8d97f43a0bba201d24f34c2d337fabb 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -34,7 +34,7 @@ namespace gpu { // class FusionMerger : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "fusion merger"; } + absl::string_view name() const override { return "fusion merger"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index b22bb1d39ba177ef42673c7a3755694b43c15d14..7cc869ed9e89688d6ea06428a7bade3ebe55ea23 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -286,6 +286,39 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { op::Fusion(op::Parameter())); } +TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) { + auto module = ParseHloString(R"( + HloModule m + + f1_computation { + f1_p0 = f32[16,16,256]{0,1,2} parameter(0) + add = f32[16,16,256]{0,1,2} add(f1_p0, f1_p0) + // Note that the copy changes the layout from {0,1,2} to {2,1,0}. 
+ ROOT f1_root = f32[16,16,256]{2,1,0} copy(add) + } + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + f2_computation { + f2_p0 = f32[16,16,256]{2,1,0} parameter(0) + f2_zero = f32[] constant(0) + ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2}, + to_apply=add_computation + } + + ENTRY entry { + p0 = f32[16,16,256]{0,1,2} parameter(0) + f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation + ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation + })") + .ValueOrDie(); + EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 74282c568c09921dbeec2e9cce79b6c73b6ea592..9c4a4903667ea1a6c99ce9e912c9d0497b8e389f 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -17,8 +17,8 @@ limitations under the License. 
#include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -186,7 +186,7 @@ StatusOr DoGemmAutotune( } return InternalError( - "Unable to autotune cuBLAS gemm on stream %p; none of the %zu algorithms " + "Unable to autotune cuBLAS gemm on stream %p; none of the %u algorithms " "ran successfully", stream, algorithms.size()); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 0c6f9b511f3aac5f62182273b827adcd068cd633..8ffae18fe820aa01701731ee56a83aeacf0eab0d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -27,7 +27,7 @@ namespace gpu { // inserting kCopy instructions. class GpuCopyInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 70608379048871cf6ee72145fa9afff71a3eabe6..31a9f9b1beb81da81a06f6dc8e7c13c105514092 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -19,8 +19,8 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -112,7 +112,7 @@ Status GpuExecutable::ExecuteThunks( // // TODO(jlebar): Should we cache the results of HloInstruction::ToString(), // since we expect it to be an expensive call? - tensorflow::gtl::optional op_annotation; + absl::optional op_annotation; if (top_level_annotation.IsEnabled()) { op_annotation.emplace( thunk->hlo_instruction() != nullptr @@ -144,7 +144,7 @@ Status GpuExecutable::ExecuteThunks( TF_RETURN_IF_ERROR( thunk->ExecuteOnStream(buffer_allocations, stream, &profiler)); if (thunk_schedule_->Depended(thunk)) { - auto finish_event = MakeUnique(main_stream->parent()); + auto finish_event = absl::make_unique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); @@ -160,7 +160,7 @@ Status GpuExecutable::ExecuteThunks( if (!block_status.ok()) { return InternalError( "Failed to complete all kernels launched on stream %p: %s", - main_stream, block_status.error_message().c_str()); + main_stream, block_status.error_message()); } } @@ -234,7 +234,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { StatusOr GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { DeviceMemoryAllocator* memory_allocator = run_options->allocator(); @@ -260,10 +260,9 @@ StatusOr GpuExecutable::ExecuteOnStream( if (buffer.is_null() && buffer.size() > 0) { return FailedPrecondition( "Cannot run XLA computation because pointer to (sub-)buffer at " - "index %s of parameter %lld was null. 
All pointers to " - "(sub-)buffers must not be null, unless the (sub-)buffer has zero " - "elements.", - allocation.param_shape_index().ToString().c_str(), param_no); + "index %s of parameter %d was null. All pointers to (sub-)buffers " + "must not be null, unless the (sub-)buffer has zero elements.", + allocation.param_shape_index().ToString(), param_no); } buffer_allocations_builder.RegisterBuffer(i, buffer); @@ -326,7 +325,7 @@ StatusOr GpuExecutable::ExecuteOnStream( StatusOr GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on GPU."); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index c7ce6d0acbbbe594040271c0d45c71c016e36514..38b0f8f15bd28cf2659e4a53b6634e981545716b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -19,6 +19,9 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -32,10 +35,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -78,12 +78,12 @@ class GpuExecutable : public Executable { // match the compute capability passed to this object's constructor. StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override; StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; + absl::Span arguments) override; private: // If `block_host_until_done` is false, execution will not block the host diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d31fd5570c468b0c42fa308535fd335f3588a79 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" + +namespace xla { +namespace gpu { + +namespace { +void AppendParams(const HloInstruction& instr, + std::vector<HloInstruction*>* params) { + if (instr.opcode() == HloOpcode::kFusion) { + params->insert(std::end(*params), std::begin(instr.fused_parameters()), + std::end(instr.fused_parameters())); + } else { + for (HloInstruction* operand : instr.operands()) { + params->push_back(operand); + } + } +} +} // namespace + +bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, + const HloInstruction& reduce) { + std::vector<HloInstruction*> params; + AppendParams(producer, &params); + AppendParams(reduce, &params); + int64 max_rank = -1; + const Layout* max_rank_layout; + for (HloInstruction* param : params) { + if (ShapeUtil::IsArray(param->shape()) && + ShapeUtil::Rank(param->shape()) > max_rank) { + max_rank = ShapeUtil::Rank(param->shape()); + max_rank_layout = &param->shape().layout(); + } + } + return absl::c_all_of(params, [&](HloInstruction* param) { + return (!ShapeUtil::IsArray(param->shape())) || + (ShapeUtil::Rank(param->shape()) < max_rank) || + (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); + }); +} + +bool IsInputFusibleReduction(const HloInstruction& instr) { + if (instr.IsMultiOutputFusion()) { + for (const HloInstruction* operand : + instr.fused_expression_root()->operands()) { + if (IsReductionToVector(*operand)) { + CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) + << " Multi-output fusion rooted at reduction-to-vector ops must be " + "of kind kInput: " + << instr.ToString(); + return true; + } + } + return false; + } else if (instr.opcode() == HloOpcode::kFusion) { + if (IsReductionToVector(*instr.fused_expression_root())) { + CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput) + << " Fusion rooted at 
reduction-to-vector op must be of kind kInput: " + << instr.ToString(); + return true; + } + return false; + } + return IsReductionToVector(instr); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h new file mode 100644 index 0000000000000000000000000000000000000000..f7c24a0d5bbfcc61389ea19ae7f769671e4e974d --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +// TODO(b/112957171): Extract logic to determine fusibility of HLO ops from +// GpuInstructionFusion, FusionMerger, and GpuMultiOutputFusion. + +namespace xla { +namespace gpu { + +// The code emitted for reduce-rooted input fusions (EmitReductionToVector) +// suffers from poor data locality if the layouts of input parameters differ. In +// such situations it is better not to fuse. Only input params with +// maximum rank are considered. Params with smaller ranks will be broadcasted +// and have not been observed to cause data locality issues. +// TODO(b/111977086): Improve reduce emitters to remove this limitation. 
+bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, + const HloInstruction& reduce); + +// Whether `instr` is fusible as the root of a reduce input fusion, i.e. `instr` +// is either an unfused reduction-to-vector op, an input fusion rooted at a +// reduction-to-vector op, or a multi-output input fusion with at least one +// reduction-to-vector op root. +// Note that reduction ops are lowered in different ways. Reduce input fusions +// are lowered by IrEmitterUnnested::EmitReductionToVector and must be rooted at +// reduction-to-vector ops. Other reduction ops are lowered by +// GpuElementalIrEmitter and fused like elementwise ops. +bool IsInputFusibleReduction(const HloInstruction& instr); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_FUSIBLE_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d91b7bc61fda5a07c163a07ec0e1644d2ad9db49 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc @@ -0,0 +1,332 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { + +using GpuFusibleTest = HloTestBase; + +const char kModulePrefix[] = R"( + HloModule test_module + scalar_add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + })"; + +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_ElementwiseProducer) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY entry { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + exp = f32[2,2,2]{2,1,0} exponential(p0) + ROOT reduce = f32[2,2]{1,0} reduce(exp, c0), dimensions={2}, to_apply=scalar_add + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + const HloInstruction* exp = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(exp->opcode(), HloOpcode::kExp); + EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*exp, *reduce)); +} + +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_MixedLayoutProducer) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + mixed_input_layouts_computation { + p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + copy = f16[128,1024,32,32]{1,3,2,0} copy(p1.1) + c0 = f16[] constant(0) + broadcast = f16[128,1024,32,32]{1,3,2,0} broadcast(c0), dimensions={} + greater-than = pred[128,1024,32,32]{1,3,2,0} greater-than(copy, broadcast) + ROOT root = f16[128,1024,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast) + } + fused_reduce { + p0.2 = 
f16[128,1024,32,32]{1,3,2,0} parameter(0) + convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2) + c0.2 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + p1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation + reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce + ROOT root = (f32[1024]{0}, f16[128,1024,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion) + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce_fusion = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(), + HloOpcode::kReduce); + const HloInstruction* loop_fusion = + module->entry_computation()->root_instruction()->operand(1); + ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kSelect); + EXPECT_FALSE( + LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion)); +} + +TEST_F(GpuFusibleTest, LayoutsAreReduceInputFusionFriendly_CopyProducer) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduce { + p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0) + c0.1 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(p0.1, c0.1), dimensions={0,2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0) + copy = f32[128,1024,32,32]{1,3,2,0} copy(p0) + ROOT reduce_fusion = f32[1024]{0} fusion(copy), kind=kInput, calls=fused_reduce + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->fused_expression_root()->opcode(), HloOpcode::kReduce); + const HloInstruction* copy = + module->entry_computation()->root_instruction()->operand(0); + 
ASSERT_EQ(copy->opcode(), HloOpcode::kCopy); + EXPECT_FALSE(LayoutsAreReduceInputFusionFriendly(*copy, *reduce)); +} + +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_LayoutChangingFusionProducer) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + layout_changing_computation { + p0.1 = f16[128,1024,32,32]{3,2,1,0} parameter(0) + p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + c0 = f16[] constant(0) + broadcast = f16[128,1024,32,32]{3,2,1,0} broadcast(c0), dimensions={} + greater-than = pred[128,1024,32,32]{3,2,1,0} greater-than(p1.1, broadcast) + select = f16[128,1024,32,32]{3,2,1,0} select(greater-than, p0.1, broadcast) + ROOT root = f16[128,1024,32,32]{1,3,2,0} copy(select) + } + fused_reduce { + p0.2 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + convert = f32[128,1024,32,32]{1,3,2,0} convert(p0.2) + c0.2 = f32[] constant(0) + ROOT reduce = f32[1024]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f16[128,1024,32,32]{3,2,1,0} parameter(0) + p1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) + loop_fusion = f16[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=layout_changing_computation + ROOT reduce_fusion = f32[1024]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce_fusion = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(), + HloOpcode::kReduce); + const HloInstruction* loop_fusion = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kCopy); + EXPECT_FALSE( + LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion)); +} + +TEST_F(GpuFusibleTest, + LayoutsAreReduceInputFusionFriendly_ConsiderMaximumRanksParamsOnly) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + broadcasting_computation { + p0.1 = 
f32[128,1024,32,32]{1,3,2,0} parameter(0) + p1.1 = f32[128]{0} parameter(1) + broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(p1.1), dimensions={0} + ROOT add = f32[128,1024,32,32]{1,3,2,0} add(p0.1, broadcast) + } + ENTRY entry { + p0 = f16[128,1024,32,32]{1,3,2,0} parameter(0) + p1 = f16[128]{0} parameter(1) + loop_fusion = f32[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=broadcasting_computation + c0.2 = f32[] constant(0) + ROOT reduce = f32[128,1024]{0,1} reduce(loop_fusion, c0.2), dimensions={0,2,3}, to_apply=scalar_add + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + const HloInstruction* loop_fusion = + module->entry_computation()->root_instruction()->operand(0); + ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kAdd); + EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY entry { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + // Reduction-to-vector lowered by IrEmitterUnnested. + ROOT reduce = f32[512]{0} reduce(p1, c0), dimensions={0,2,3}, to_apply=scalar_add + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_TRUE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_ElementalReduction) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + ENTRY entry { + c0 = f32[] parameter(0) + p1 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(1) + // Reduction lowered by GpuElementalIrEmitter. 
+ ROOT reduce = f32[8,512,5,1,1]{4,3,2,1,0} reduce(p1, c0), dimensions={3}, to_apply=scalar_add + })")) + .ValueOrDie(); + SCOPED_TRACE(module->ToString()); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce); + EXPECT_FALSE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + ROOT reduce = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = f32[128,512]{1,0} fusion(p0), kind=kInput, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(1) + ROOT reduce = f32[8,5,1,1]{3,2,1,0} reduce(p1, c0), dimensions={1,3}, to_apply=scalar_add + } + ENTRY entry { + p0 = f32[8,512,5,16,1,1]{5,4,3,2,1,0} parameter(0) + ROOT fusion = f32[8,5,1,1]{3,2,1,0} fusion(p0), kind=kLoop, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + reduce.0 = 
f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add + reduce.1 = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add + ROOT root = (f32[128,512]{1,0}, f32[128,512]{1,0}) tuple(reduce.0, reduce.1) + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = (f32[128,512]{1,0}, f32[128,512]{1,0}) fusion(p0), kind=kInput, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, + IsInputFusibleReduction_MultiOutputInputReduceFusionWithExtraOutputs) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + reduce = f32[128,512]{1,0} reduce(p1, c0), dimensions={2,3}, to_apply=scalar_add + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1) + ROOT root = (f32[128,512]{1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(reduce, mul) + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = (f32[128,512]{1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kInput, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_TRUE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + reduce.0 = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add + reduce.1 = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add + ROOT root = (f32[512,28]{1,0}, f32[512,28]{1,0}) tuple(reduce.0, reduce.1) + } + ENTRY entry { + p0 = 
f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = (f32[512,28]{1,0}, f32[512,28]{1,0}) fusion(p0), kind=kLoop, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsInputFusibleReduction(*reduce)); +} + +TEST_F(GpuFusibleTest, + IsInputFusibleReduction_MultiOutputLoopFusionReduceAndElementwiseOp) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( + fused_reduction { + c0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + reduce = f32[512,28]{1,0} reduce(p1, c0), dimensions={0,2}, to_apply=scalar_add + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1, p1) + ROOT root = (f32[512,28]{1,0}, f32[128,512,28,28]{3,2,1,0}) tuple(reduce, mul) + } + ENTRY entry { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + ROOT fusion = (f32[512,28]{1,0}, f32[128,512,28,28]{3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_reduction + })")) + .ValueOrDie(); + const HloInstruction* reduce = + module->entry_computation()->root_instruction(); + ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion); + EXPECT_FALSE(IsInputFusibleReduction(*reduce)); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc similarity index 94% rename from tensorflow/compiler/xla/service/gpu/hlo_schedule.cc rename to tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 19de37b0fbed15455e8c6a9bfe427ba3d9f0a9dc..743035a84eaeb41fafb336844a1a7a07b82af4db 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -17,9 +17,9 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" @@ -59,8 +59,8 @@ GpuHloOrdering::GpuHloOrdering( : PredecessorHloOrdering(module) { // The entry computation has a total order when there's only one stream. if (stream_assignment.StreamCount() == 1) { - entry_sequence_ = - MakeUnique>(thunk_launch_order); + entry_sequence_ = absl::make_unique>( + thunk_launch_order); } // The ordering of instructions for the entry computation is determined by the @@ -75,7 +75,7 @@ GpuHloOrdering::GpuHloOrdering( // same-stream predecessors of each instruction. // Compute the set of all instructions we will want to set reachability on. - auto predecessor_map = MakeUnique( + auto predecessor_map = absl::make_unique( module->entry_computation()->MakeInstructionPostOrder()); // The most recently visited instruction per stream. @@ -184,13 +184,13 @@ void BFSLaunchOrder(const HloComputation* computation, } // end namespace -HloSchedule::HloSchedule() {} +GpuHloSchedule::GpuHloSchedule() {} /* static */ -StatusOr> HloSchedule::Build( +StatusOr> GpuHloSchedule::Build( const HloModule& module, const StreamAssignment& stream_assignment, int64 pointer_size) { - std::unique_ptr schedule(new HloSchedule); + std::unique_ptr schedule(new GpuHloSchedule); // Initialize thunk_launch_order_, the total order of thunk launches. 
const HloComputation* entry_computation = module.entry_computation(); @@ -208,7 +208,7 @@ StatusOr> HloSchedule::Build( BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_); } - schedule->hlo_ordering_ = MakeUnique( + schedule->hlo_ordering_ = absl::make_unique( &module, stream_assignment, schedule->thunk_launch_order_); return std::move(schedule); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h similarity index 84% rename from tensorflow/compiler/xla/service/gpu/hlo_schedule.h rename to tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h index 1ce7a48ac8fcbbad0b3697845681582fe806b322..30a0e7cecd202e83898d34e00b5b49684d1b1b68 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_ #include #include @@ -34,11 +34,11 @@ namespace gpu { // schedule is used by BufferAssigner to determine buffer liveness (i.e. to // minimize allocations), and also by ThunkSchedule to determine the thunk // launch order. -class HloSchedule { +class GpuHloSchedule { public: - // Constructs an HloSchedule for the given module, based on the given stream - // assignment. - static StatusOr> Build( + // Constructs an GpuHloSchedule for the given module, based on the given + // stream assignment. 
+ static StatusOr> Build( const HloModule& module, const StreamAssignment& stream_assignment, int64 pointer_size); @@ -56,7 +56,7 @@ class HloSchedule { } private: - HloSchedule(); + GpuHloSchedule(); std::vector thunk_launch_order_; std::unique_ptr hlo_ordering_; @@ -65,4 +65,4 @@ class HloSchedule { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SCHEDULE_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc similarity index 95% rename from tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc rename to tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 45f0a1c645b2875cf90d2c11cfb66c3dd855d097..0922e44a126eadab17d60d9ece53aae8d8f1c218 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -29,16 +30,16 @@ limitations under the License. namespace xla { namespace gpu { -class HloScheduleTest : public HloTestBase { +class GpuHloScheduleTest : public HloTestBase { protected: using HloVec = std::vector; // Pre-canned shapes. 
Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); - static std::unique_ptr BuildHloSchedule( + static std::unique_ptr BuildGpuHloSchedule( const HloModule& module, const StreamAssignment& streams) { - return HloSchedule::Build(module, streams, /*pointer_size=*/8) + return GpuHloSchedule::Build(module, streams, /*pointer_size=*/8) .ConsumeValueOrDie(); } @@ -47,7 +48,7 @@ class HloScheduleTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", config); + return absl::make_unique("test_module", config); } HloVec RemoveHlo(const HloVec& input, @@ -64,7 +65,7 @@ class HloScheduleTest : public HloTestBase { // Test of a single stream, where data dependencies fully determine the // execution order. -TEST_F(HloScheduleTest, SequentialMatMul) { +TEST_F(GpuHloScheduleTest, SequentialMatMul) { HloComputation::Builder builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); @@ -84,7 +85,7 @@ TEST_F(HloScheduleTest, SequentialMatMul) { EXPECT_EQ(streams->StreamNumberForHlo(*dot1), streams->StreamNumberForHlo(*dot2)); - auto schedule = BuildHloSchedule(*module, *streams); + auto schedule = BuildGpuHloSchedule(*module, *streams); // Remove parameters, which are unordered. EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}), HloVec({dot1, dot2})); @@ -122,7 +123,7 @@ TEST_F(HloScheduleTest, SequentialMatMul) { // Test of a single stream, where data dependencies do not fully determine the // execution order, but the stream assignment does. 
-TEST_F(HloScheduleTest, SequentialAdd) { +TEST_F(GpuHloScheduleTest, SequentialAdd) { HloComputation::Builder builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); @@ -146,7 +147,7 @@ TEST_F(HloScheduleTest, SequentialAdd) { EXPECT_EQ(streams->StreamNumberForHlo(*add1), streams->StreamNumberForHlo(*add3)); - auto schedule = BuildHloSchedule(*module, *streams); + auto schedule = BuildGpuHloSchedule(*module, *streams); // Remove parameters, which are unordered. EXPECT_EQ(RemoveHlo(schedule->ThunkLaunchOrder(), {x, y, z}), HloVec({add1, add2, add3})); @@ -194,7 +195,7 @@ TEST_F(HloScheduleTest, SequentialAdd) { } // Test of two streams. -TEST_F(HloScheduleTest, ConcurrentMatMul) { +TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { HloComputation::Builder builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); @@ -214,7 +215,7 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) { EXPECT_NE(streams->StreamNumberForHlo(*dot1), streams->StreamNumberForHlo(*dot2)); - auto schedule = BuildHloSchedule(*module, *streams); + auto schedule = BuildGpuHloSchedule(*module, *streams); // Remove parameters, which are unordered. HloVec thunk_launch_order = RemoveHlo(schedule->ThunkLaunchOrder(), {x, y}); EXPECT_TRUE(thunk_launch_order == HloVec({dot1, dot2, add}) || @@ -250,7 +251,7 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) { } // Test of multiple streams. 
-TEST_F(HloScheduleTest, LatticeMatMul) { +TEST_F(GpuHloScheduleTest, LatticeMatMul) { // d00 -- layer 0 // / \ // d10 d11 -- layer 1 @@ -265,7 +266,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) { params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( - i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); @@ -306,7 +307,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) { // We don't check the thunk launch order, since there are many valid total // orders, and it's annoying to express. - auto schedule = BuildHloSchedule(*module, *streams); + auto schedule = BuildGpuHloSchedule(*module, *streams); auto order = schedule->ConsumeHloOrdering(); const HloVec all_params( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc index 4944c41f7d8dc7a78a3cd094aee4d7087c74857e..4268fb2c7a813b3b53e4cd48746028a7b369f28e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc @@ -34,9 +34,8 @@ StatusOr GpuHloSupportChecker::Run(HloModule* module) { return xla::Unimplemented( "GPU backend does not support HLO instruction %s with shape " "containing a sparse layout: %s", - instruction->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction->shape()) - .c_str()); + instruction->ToString(), + ShapeUtil::HumanStringWithLayout(instruction->shape())); } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h index d63e213d2b1efab4bcff75541cc5ab33d7a07976..bbb3340760c8330bd6570f33382f004315c6d0bd 100644 --- 
a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -28,9 +28,7 @@ class GpuHloSupportChecker : public HloPassInterface { GpuHloSupportChecker() = default; ~GpuHloSupportChecker() override = default; - tensorflow::StringPiece name() const override { - return "gpu_hlo_support_checker"; - } + absl::string_view name() const override { return "gpu_hlo_support_checker"; } // Note: always returns false (no instructions are ever modified by this // pass). diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 286547ebae2f1a4b8d783a06d13b4dd96052b952..fbc8ddf599570b90e93eb463a1fd6c275b73711c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -119,7 +120,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); @@ -192,7 +193,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { // Enumerate all combinations of shapes. 
for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); @@ -265,7 +266,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { for (int constrained_param_no : {0, 4}) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index a2f53f844613da9fe8166489dc9959e8d30c6332..f3c274429242d5c989146d14ea523b5910408cff 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "llvm/IR/DataLayout.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -83,7 +84,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } infeed_manager->EnqueueDestination(std::move(buffers)); @@ -96,7 +97,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( StatusOr GpuTransferManager::TransferBufferToInfeedInternal( se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); + return InvalidArgument("Infeed shape is too large: needs %d bytes", size); } if (size == 0) { @@ -160,9 +161,10 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( if (ShapeUtil::IsTuple(shape)) { return; } - *buffer = MakeUnique(GetByteSizeRequirement(shape)); + *buffer = absl::make_unique( + GetByteSizeRequirement(shape)); (*buffer)->set_destination( - MakeUnique(literal, index)); + absl::make_unique(literal, index)); }); // Give the tree of buffers to the outfeed mananger. 
The device will fill it @@ -179,7 +181,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( } // namespace xla static std::unique_ptr CreateNVPTXTransferManager() { - return xla::MakeUnique( + return absl::make_unique( /*id=*/stream_executor::cuda::kCudaPlatformId, /*pointer_size=*/llvm::DataLayout(xla::gpu::NVPTXCompiler::kDataLayout) .getPointerSize(0 /* default address space */)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index 7929042869763dfeab2fe8f87093b7ea758337d0..fa88816bc8b0bf41f05358c0089b381305ed3182 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ #include @@ -61,4 +61,4 @@ class GpuTransferManager : public GenericTransferManager { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc index 17226769302eef0dd01550b0bc5404e889ad78f8..b9c21e8edb2bdde03acb1fe6197a399724c9c8ab 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -33,7 +34,7 @@ namespace gpu { namespace { void InitAndStartTimer(std::stack>* timers, se::Stream* stream) { - timers->push(MakeUnique(stream->parent())); + timers->push(absl::make_unique(stream->parent())); stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get()); } @@ -115,7 +116,7 @@ HloExecutionProfiler::MakeScopedInstructionProfiler( CHECK(hlo_instructions_.insert(hlo_instruction).second) << hlo_instruction->name(); } - return MakeUnique(this, hlo_instruction); + return absl::make_unique(this, hlo_instruction); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 8c11cd05419289d82b033c936bb60884f45cb636..51627402b45f594dab3480129ba182d54d01b811 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -24,20 +25,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace gpu { -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; void HloToIrBindings::EmitBasePointersForHlos( - tensorflow::gtl::ArraySlice io_hlos, - tensorflow::gtl::ArraySlice non_io_hlos) { + absl::Span io_hlos, + absl::Span non_io_hlos) { // I/O HLOs are bound to the arguments of the current IR function. I.e., // // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index eee40b0e91fc03013a6978ae3cfe42b87633eed7..c0edae530cedba45c897b07b7b9cc72eaaab397c 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" @@ -25,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace gpu { @@ -45,8 +45,8 @@ class HloToIrBindings { alias_analysis_(module, *buffer_assignment_, &b_->getContext()) {} void EmitBasePointersForHlos( - tensorflow::gtl::ArraySlice io_hlos, - tensorflow::gtl::ArraySlice non_io_hlos); + absl::Span io_hlos, + absl::Span non_io_hlos); // Rebinds the given HLO to the LLVM IR value that represent its address. void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index c5f0cdf6cd5d3e076bffa875fbba991bf0681ee8..a4364b0deb6c97b7b580e18bf67d5f3a8fd3cc62 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" namespace xla { namespace gpu { @@ -24,7 +24,7 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(host_to_device_stream_mu_); if (host_to_device_executor_ == nullptr) { host_to_device_executor_ = executor; - host_to_device_stream_ = MakeUnique(executor); + host_to_device_stream_ = absl::make_unique(executor); host_to_device_stream_->Init(); } diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index fee6d2af3bfd4976f5845edf592e8310b55a3feb..8c3a026740851767855beae59d6a3c92f7a0d6bd 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -96,7 +96,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } VLOG(2) << "Infeeding to GPU complete"; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 0f2c83aeb2633a007559d8caac78ea2d233539ed..4d5d8e99f88149aabfd0a4aeafc7e6724d29418d 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" @@ -26,7 +27,7 @@ namespace gpu { namespace { -bool IsFusile(const HloInstruction& hlo) { +bool IsFusible(const HloInstruction& hlo) { // Don't fuse get-tuple-element on GPU: We can, but it's slower than not // fusing. We never generate kernels for unfused GTEs. Instead, if an // unfused GTE is an input to a kernel (including a fusion kernel), we @@ -41,7 +42,7 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kFusion || hlo.opcode() == HloOpcode::kGather || - hlo.opcode() == HloOpcode::kPad || + hlo.opcode() == HloOpcode::kIota || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || hlo.opcode() == HloOpcode::kReshape || @@ -221,6 +222,13 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Do not fuse into reduce input fusions if the resulting kernel would suffer + // from poor data locality (due to unfriendly input layouts). + if (IsInputFusibleReduction(*consumer) && + !LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) { + return false; + } + // We can't fuse library calls, so if a user of such an op could become a // bitcast, leave it unfused. See `xla::InstructionFusion::ShouldFuse` for // further rationale. 
@@ -245,7 +253,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return true; } - if (!IsFusile(*producer) || !IsFusile(*consumer) || + if (!IsFusible(*producer) || !IsFusible(*consumer) || !InstructionFusion::ShouldFuse(consumer, operand_index)) { return false; } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 8d0522bd8fd6659e64d18c52807df8dc7fc2f3b8..bca775c4750dd3aa679846d54e29a9d277adad79 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -171,6 +171,78 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) { op::Reduce(op::Broadcast(op::Constant()), op::Constant())); } +TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduce) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY entry { + p0 = f32[16,16,16,16]{3,2,1,0} parameter(0) + copy = f32[16,16,16,16]{0,1,2,3} copy(p0) + constant.1 = f32[] constant(0) + ROOT reduce = f32[16] reduce(copy, constant.1), dimensions={0,1,2}, to_apply=add + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, DoNotFuseLayoutChangingOpWithReduceFusion) { + auto module = ParseHloString(R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + fused_reduce { + p0.1 = f32[16,16,16,16]{0,1,2,3} parameter(0) + mul = f32[16,16,16,16]{0,1,2,3} multiply(p0.1, p0.1) + c0.1 = f32[] constant(0) + ROOT root = f32[] reduce(mul, c0.1), dimensions={0,1,2,3}, to_apply=add + } + + ENTRY entry { + p0 = f32[16,16,16,16]{3,2,1,0} parameter(0) + copy = f32[16,16,16,16]{0,1,2,3} copy(p0) + fusion = f32[] fusion(copy), kind=kInput, 
calls=fused_reduce + ROOT root = (f32[]) tuple(fusion) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, FuseLayoutChangingOpWithElementwise) { + auto module = ParseHloString(R"( + HloModule test_module + ENTRY entry { + p0 = f32[16,16,16,16]{3,2,1,0} parameter(0) + copy = f32[16,16,16,16]{0,1,2,3} copy(p0) + ROOT add = f32[16,16,16,16]{0,1,2,3} add(copy, copy) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), op::Add(op::Copy(), op::Copy())); +} + TEST_F(InstructionFusionTest, BitcastIntoAdd) { auto module = ParseHloString(R"( HloModule test_module @@ -365,7 +437,7 @@ static StatusOr FindHloInstruction( } return NotFound( "Computation '%s' does not contain an instruction with op code '%s'.", - computation.name().c_str(), HloOpcodeString(op).c_str()); + computation.name(), HloOpcodeString(op)); } TEST_F(InstructionFusionTest, MultiOutputFusion) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index c349063c71f000435a05306101ad724505f2d197..20d523abe0552f0bc22c365007c096666ec888f6 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -144,10 +144,12 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) { IsCustomCallToDnnConvolution(hlo); } -static HloInstruction* CreateCudnnConv( - const char* call_target, const Shape& shape, HloInstruction* lhs, - HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dnums) { +static HloInstruction* CreateCudnnConv(const char* call_target, + const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, 
+ const Window& window, + const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count) { HloComputation* computation = lhs->parent(); // This call returns a tuple of (conv_result, scratch_memory), where @@ -165,28 +167,34 @@ static HloInstruction* CreateCudnnConv( HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); custom_call->set_window(window); custom_call->set_convolution_dimension_numbers(dnums); + custom_call->set_feature_group_count(feature_group_count); return custom_call; } -HloInstruction* CreateCudnnConvForward( - const Shape& shape, HloInstruction* input, HloInstruction* kernel, - const Window& window, const ConvolutionDimensionNumbers& dnums) { +HloInstruction* CreateCudnnConvForward(const Shape& shape, + HloInstruction* input, + HloInstruction* kernel, + const Window& window, + const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count) { return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel, - window, dnums); + window, dnums, feature_group_count); } HloInstruction* CreateCudnnConvBackwardInput( const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, - const Window& window, const ConvolutionDimensionNumbers& dnums) { + const Window& window, const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count) { return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output, - reverse_filter, window, dnums); + reverse_filter, window, dnums, feature_group_count); } HloInstruction* CreateCudnnConvBackwardFilter( const Shape& shape, HloInstruction* input, HloInstruction* output, - const Window& window, const ConvolutionDimensionNumbers& dnums) { + const Window& window, const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count) { return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input, - output, window, dnums); + output, window, dnums, feature_group_count); } bool IsReductionToVector(const HloInstruction& reduce) { @@ -215,8 +223,8 @@ bool 
IsReductionToVector(const HloInstruction& reduce) { // This emits a device-side call to // "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see // http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls -llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, - tensorflow::gtl::ArraySlice arguments, +llvm::Value* EmitPrintf(absl::string_view fmt, + absl::Span arguments, llvm::IRBuilder<>* builder) { std::vector argument_types; for (auto argument : arguments) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 5d23a3d01842c7b4ff405171cd49c96a19f7e5b0..59c65fc2686cd4a00a3770ebaedf637e8f556828 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -109,15 +109,20 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo); // // The created cudnn call will use the default cudnn algorithm and no scratch // space. 
-HloInstruction* CreateCudnnConvForward( - const Shape& shape, HloInstruction* input, HloInstruction* kernel, - const Window& window, const ConvolutionDimensionNumbers& dnums); +HloInstruction* CreateCudnnConvForward(const Shape& shape, + HloInstruction* input, + HloInstruction* kernel, + const Window& window, + const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count); HloInstruction* CreateCudnnConvBackwardInput( const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, - const Window& window, const ConvolutionDimensionNumbers& dnums); + const Window& window, const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count); HloInstruction* CreateCudnnConvBackwardFilter( const Shape& shape, HloInstruction* input, HloInstruction* output, - const Window& window, const ConvolutionDimensionNumbers& dnums); + const Window& window, const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count); // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm // or cuDNN convolution. @@ -126,8 +131,8 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo); bool IsReductionToVector(const HloInstruction& reduce); // Emits call to "vprintf" with given format and arguments. -llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, - tensorflow::gtl::ArraySlice arguments, +llvm::Value* EmitPrintf(absl::string_view fmt, + absl::Span arguments, llvm::IRBuilder<>* builder); // Emits code to shuffle data between threads of a warp. This has the same diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 6675dbd3f9eef8d13c9dec200e5bf47faa5b514d..ffca5d6549a8316a7c7b7946d9943f091c133d1b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/algorithm/container.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" @@ -140,7 +141,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { Status IrEmitter::EmitCallToNestedComputation( const HloComputation& nested_computation, - tensorflow::gtl::ArraySlice operands, llvm::Value* output) { + absl::Span operands, llvm::Value* output) { TF_RET_CHECK(nested_computation.num_parameters() > 0); llvm::Function*& emitted_function = computation_to_ir_function_[&nested_computation]; @@ -155,7 +156,7 @@ Status IrEmitter::EmitCallToNestedComputation( std::vector arguments(operands.begin(), operands.end()); arguments.push_back(output); arguments.push_back(bindings_.GetTempBufferBase()); - b_.CreateCall(emitted_function, arguments); + Call(emitted_function, arguments); return Status::OK(); } @@ -177,7 +178,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( computation.root_instruction()->shape().element_type(); bool is_atomic_integral = element_type == S32 || element_type == U32 || element_type == S64 || element_type == U64; - llvm::Value* source = b_.CreateLoad(source_address, "source"); + llvm::Value* source = Load(source_address, "source"); if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. if (element_type == F32) { @@ -189,8 +190,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( } if (is_atomic_integral) { // integral + integral - b_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } } @@ -201,8 +202,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( auto opcode = primitive_util::IsSignedIntegralType(element_type) ? 
llvm::AtomicRMWInst::Max : llvm::AtomicRMWInst::UMax; - b_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -211,8 +212,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Min : llvm::AtomicRMWInst::UMin; - b_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -291,10 +292,10 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, // cas_old_output_address and cas_new_output_address point to the scratch // memory where we store the old and new values for the repeated atomicCAS // operations. - llvm::Value* cas_old_output_address = b_.CreateAlloca( - atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); - llvm::Value* cas_new_output_address = b_.CreateAlloca( - atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); + llvm::Value* cas_old_output_address = + Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); + llvm::Value* cas_new_output_address = + Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); // Emit preparation code to the preheader. 
llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock(); @@ -308,29 +309,26 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, CHECK_EQ((element_size % sizeof(char)), 0); llvm::Type* address_int_type = module_->getDataLayout().getIntPtrType(output_address_type); - atomic_memory_address = b_.CreatePtrToInt(output_address, address_int_type); + atomic_memory_address = PtrToInt(output_address, address_int_type); llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3); - llvm::Value* offset = b_.CreateAnd(atomic_memory_address, mask); + llvm::Value* offset = And(atomic_memory_address, mask); mask = llvm::ConstantInt::get(address_int_type, -4); - atomic_memory_address = b_.CreateAnd(atomic_memory_address, mask); + atomic_memory_address = And(atomic_memory_address, mask); atomic_memory_address = - b_.CreateIntToPtr(atomic_memory_address, atomic_address_type); - binop_output_address = b_.CreateAdd( - b_.CreatePtrToInt(cas_new_output_address, address_int_type), offset); + IntToPtr(atomic_memory_address, atomic_address_type); binop_output_address = - b_.CreateIntToPtr(binop_output_address, element_address_type); + Add(PtrToInt(cas_new_output_address, address_int_type), offset); + binop_output_address = IntToPtr(binop_output_address, element_address_type); } else { - atomic_memory_address = - b_.CreateBitCast(output_address, atomic_address_type); + atomic_memory_address = BitCast(output_address, atomic_address_type); binop_output_address = - b_.CreateBitCast(cas_new_output_address, element_address_type); + BitCast(cas_new_output_address, element_address_type); } // Use the value from the memory that atomicCAS operates on to initialize // cas_old_output. 
- llvm::Value* cas_old_output = - b_.CreateLoad(atomic_memory_address, "cas_old_output"); - b_.CreateStore(cas_old_output, cas_old_output_address); + llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output"); + Store(cas_old_output, cas_old_output_address); llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( b_.GetInsertPoint(), "atomic_op_loop_exit"); @@ -343,32 +341,29 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, // Emit the body of the loop that repeatedly invokes atomicCAS. // // Use cas_old_output to initialize cas_new_output. - cas_old_output = b_.CreateLoad(cas_old_output_address, "cas_old_output"); - b_.CreateStore(cas_old_output, cas_new_output_address); + cas_old_output = Load(cas_old_output_address, "cas_old_output"); + Store(cas_old_output, cas_new_output_address); // Emits code to calculate new_output = operation(old_output, source); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( computation, {binop_output_address, source_address}, binop_output_address)); - llvm::Value* cas_new_output = - b_.CreateLoad(cas_new_output_address, "cas_new_output"); + llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output"); // Emit code to perform the atomicCAS operation // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, // cas_new_output); - llvm::Value* ret_value = b_.CreateAtomicCmpXchg( - atomic_memory_address, cas_old_output, cas_new_output, - llvm::AtomicOrdering::SequentiallyConsistent, - llvm::AtomicOrdering::SequentiallyConsistent); + llvm::Value* ret_value = + AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output, + llvm::AtomicOrdering::SequentiallyConsistent, + llvm::AtomicOrdering::SequentiallyConsistent); // Extract the memory value returned from atomicCAS and store it as // cas_old_output. 
- b_.CreateStore(b_.CreateExtractValue(ret_value, 0, "cas_old_output"), - cas_old_output_address); + Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address); // Extract the success bit returned from atomicCAS and generate a // conditional branch on the success bit. - b_.CreateCondBr(b_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, - loop_body_bb); + CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb); // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. @@ -383,8 +378,8 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation( // TODO(b/30258929): We only accept binary computations so far. return Unimplemented( "We only support atomic functions with exactly two parameters, but " - "computation %s has %lld.", - computation.name().c_str(), computation.num_parameters()); + "computation %s has %d.", + computation.name(), computation.num_parameters()); } if (MaybeEmitDirectAtomicOperation(computation, output_address, @@ -471,10 +466,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_value, rhs_value, &b_); result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); - result = b_.CreateInsertValue(result, value.first, {0}); - result = b_.CreateInsertValue(result, value.second, {1}); + result = InsertValue(result, value.first, {0}); + result = InsertValue(result, value.second, {1}); } else { - result = b_.CreateFMul(lhs_value, rhs_value); + result = FMul(lhs_value, rhs_value); } target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_); return Status::OK(); @@ -518,7 +513,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // We don't have to iterate over the batch dimensions in both arrays, simplify // the loop nest of the rhs. 
for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) { - DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i)); + DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i)); rhs_index[i] = lhs_index[i]; } @@ -558,21 +553,21 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt()); llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_); llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_); - llvm::Value* accum = b_.CreateLoad(accum_address); + llvm::Value* accum = Load(accum_address); llvm::Value* updated_accum; if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_element, rhs_element, &b_); llvm::Value* accum_real = Real(accum, &b_); - llvm::Value* real_sum = b_.CreateFAdd(accum_real, value.first); - updated_accum = b_.CreateInsertValue(accum, real_sum, {0}); + llvm::Value* real_sum = FAdd(accum_real, value.first); + updated_accum = InsertValue(accum, real_sum, {0}); llvm::Value* accum_imag = Imag(accum, &b_); - llvm::Value* imag_sum = b_.CreateFAdd(accum_imag, value.second); - updated_accum = b_.CreateInsertValue(updated_accum, imag_sum, {1}); + llvm::Value* imag_sum = FAdd(accum_imag, value.second); + updated_accum = InsertValue(updated_accum, imag_sum, {1}); } else { - llvm::Value* product = b_.CreateFMul(lhs_element, rhs_element); - updated_accum = b_.CreateFAdd(accum, product); + llvm::Value* product = FMul(lhs_element, rhs_element); + updated_accum = FAdd(accum, product); } - b_.CreateStore(updated_accum, accum_address); + Store(updated_accum, accum_address); // After the reduction loop exits, store the accumulator into the target // address. 
The index into the target address is the concatenation of the rhs @@ -594,7 +589,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_); target_array.EmitWriteArrayElement( target_index, - b_.CreateLoad(accum_address), // The value written to the target array. + Load(accum_address), // The value written to the target array. &b_); // Set the IR builder insert point to the exit basic block of the outer most @@ -638,17 +633,16 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { } auto arg = reduce->operand(0); auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); return EmitTargetElementLoop( *reduce, [=](const llvm_ir::IrArray::Index& index) -> StatusOr { // Initialize an accumulator with init_value. llvm::AllocaInst* accumulator_addr = - b_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType( + Alloca(llvm_ir::PrimitiveTypeToIrType( reduce->shape().element_type(), module_)); - b_.CreateStore(b_.CreateLoad(GetBasePointer(*init_value)), - accumulator_addr); + Store(Load(GetBasePointer(*init_value)), accumulator_addr); // The enclosing loops go over all the target elements. Now we have to // compute the actual target element. For this, we build a new loop nest @@ -685,7 +679,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { *function, {accumulator_addr, input_address}, accumulator_addr)); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_addr); + return Load(accumulator_addr); }); } @@ -752,14 +746,9 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { "to a cudnn CustomCall using CudnnBatchNormRewriter."); } -Status IrEmitter::HandleIota(HloInstruction*) { - // TODO(b/64798317): implement iota on GPU. 
- return Unimplemented("Iota is not implemented on GPU."); -} - StatusOr IrEmitter::ComputeNestedElement( const HloComputation& computation, - tensorflow::gtl::ArraySlice parameter_elements) { + absl::Span parameter_elements) { llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType( computation.root_instruction()->shape().element_type(), module_), @@ -768,11 +757,11 @@ StatusOr IrEmitter::ComputeNestedElement( for (llvm::Value* parameter_element : parameter_elements) { parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry( parameter_element->getType(), "parameter_buffer", &b_)); - b_.CreateStore(parameter_element, parameter_buffers.back()); + Store(parameter_element, parameter_buffers.back()); } TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers, return_buffer)); - return b_.CreateLoad(return_buffer); + return Load(return_buffer); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 561c6838798aa92ce2c96b3c45d5ba42fe6edef3..579268f07185fd2d8ec74750f1bf833101149437 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -35,13 +37,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -64,7 +65,8 @@ namespace gpu { // IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is // not a subclass of gpu::IrEmitter, and in fact is better understood as an IR // generator generator. See comments on that class. -class IrEmitter : public DfsHloVisitorWithDefault { +class IrEmitter : public DfsHloVisitorWithDefault, + public IrBuilderMixin { public: IrEmitter(const IrEmitter&) = delete; IrEmitter& operator=(const IrEmitter&) = delete; @@ -95,10 +97,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; - Status HandleIota(HloInstruction* iota) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } + llvm::IRBuilder<>* builder() { return &b_; } + protected: // Constructs an IrEmitter with the given IrEmitter context. // ir_emitter_context is owned by the caller and should outlive the IrEmitter @@ -140,9 +143,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Emits a call in IR to the given nested computation with the given operands // and output. 
If no IR function has been previously emitted for the // computation, also emits such a function. - Status EmitCallToNestedComputation( - const HloComputation& nested_computation, - tensorflow::gtl::ArraySlice operands, llvm::Value* output); + Status EmitCallToNestedComputation(const HloComputation& nested_computation, + absl::Span operands, + llvm::Value* output); // Emits an atomic operation that implements `nested_computation` in the // sequentially consistent memory model. `output_address` and `source_address` @@ -196,7 +199,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { StatusOr ComputeNestedElement( const HloComputation& computation, - tensorflow::gtl::ArraySlice parameter_elements); + absl::Span parameter_elements); // Emits an atomic operation that implements `nested_computation` in the // sequentially consistent memory model. `output_address` and `source_address` diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1e81cbde35372d9f7d6ee234d2408038d6f99dc7..389a98facb9b553a91342bb7fc42642179aaf698 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -21,6 +21,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -29,7 +35,6 @@ limitations under the License. 
#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" @@ -76,8 +81,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -85,13 +88,12 @@ namespace gpu { namespace { +using absl::InlinedVector; +using absl::nullopt; +using absl::optional; +using absl::StrCat; using llvm_ir::IrArray; using llvm_ir::IrName; -using tensorflow::gtl::ArraySlice; -using tensorflow::gtl::InlinedVector; -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; -using tensorflow::strings::StrCat; // If a dimensions is smaller than this, untiled transposition may be more // efficient. @@ -173,7 +175,7 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { llvm::Function* IrEmitterUnnested::BuildKernelPrototype( const HloInstruction& inst, - tensorflow::gtl::ArraySlice args) { + absl::Span args) { // Compute the kernel name. The opcode string may contain "-" which cannot be // in a PTX function name, so sanitize the name before uniquifying it. 
string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( @@ -314,13 +316,13 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, }; // Check the size of input tensors - if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { + if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { return i64_ty; } // Check the size of the internal result tensors if (unnested_hlo->opcode() == HloOpcode::kFusion) { - if (!c_all_of( + if (!absl::c_all_of( unnested_hlo->fused_instructions_computation()->instructions(), hlo_shape_in_range)) { return i64_ty; @@ -383,7 +385,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { int64 feature_index_value = feature_index->literal().Get({}); thunk_sequence_->emplace_back( - MakeUnique( + absl::make_unique( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -413,7 +415,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); thunk_sequence_->emplace_back( - MakeUnique( + absl::make_unique( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -443,19 +445,20 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_grad_offset = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); - thunk_sequence_->emplace_back(MakeUnique( - /*operand=*/GetAllocationSlice(*custom_call->operand(0)), - /*scale=*/GetAllocationSlice(*custom_call->operand(1)), - /*mean=*/GetAllocationSlice(*custom_call->operand(2)), - 
/*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), - /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, - /*output_grad_data=*/output_grad_data, - /*output_grad_scale=*/output_grad_scale, - /*output_grad_offset=*/output_grad_offset, - /*output_tuple=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); + thunk_sequence_->emplace_back( + absl::make_unique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*mean=*/GetAllocationSlice(*custom_call->operand(2)), + /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), + /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_grad_data=*/output_grad_data, + /*output_grad_scale=*/output_grad_scale, + /*output_grad_offset=*/output_grad_offset, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); return Status::OK(); } @@ -475,7 +478,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { const auto& target = custom_call->custom_call_target(); std::unique_ptr thunk; if (target == kCudnnConvForwardCallTarget) { - thunk = MakeUnique( + thunk = absl::make_unique( CudnnConvKind::kForward, /*input_buffer=*/lhs_slice, /*filter_buffer=*/rhs_slice, @@ -486,10 +489,10 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - backend_config.algorithm(), backend_config.tensor_ops_enabled(), - custom_call); + custom_call->feature_group_count(), backend_config.algorithm(), + backend_config.tensor_ops_enabled(), custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { - thunk = MakeUnique( + thunk = absl::make_unique( CudnnConvKind::kBackwardInput, 
/*input_buffer=*/conv_result_slice, /*filter_buffer=*/rhs_slice, @@ -500,10 +503,10 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - backend_config.algorithm(), backend_config.tensor_ops_enabled(), - custom_call); + custom_call->feature_group_count(), backend_config.algorithm(), + backend_config.tensor_ops_enabled(), custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { - thunk = MakeUnique( + thunk = absl::make_unique( CudnnConvKind::kBackwardFilter, /*input_buffer=*/lhs_slice, /*filter_buffer=*/conv_result_slice, @@ -514,8 +517,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - backend_config.algorithm(), backend_config.tensor_ops_enabled(), - custom_call); + custom_call->feature_group_count(), backend_config.algorithm(), + backend_config.tensor_ops_enabled(), custom_call); } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); @@ -552,10 +555,10 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { } VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); std::vector> thunks; - ArraySlice output_instructions = + absl::Span output_instructions = root->opcode() == HloOpcode::kTuple ? root->operands() - : ArraySlice(&root, 1); + : absl::Span(&root, 1); // For multi-output fusion emit an initializer for each tuple element. // Otherwise it's sufficient to just initialize the single output. 
@@ -576,7 +579,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunks.push_back( BuildKernelThunk(fusion, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), fusion)); + absl::make_unique(std::move(thunks), fusion)); std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand, *fusion)); @@ -714,8 +717,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { Status IrEmitterUnnested::EmitExtraOutputsForReduce( const HloInstruction* reduce, const IrArray::Index& index, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span> extra_output_gens) { for (int i = 0; i != extra_output_gens.size(); ++i) { const HloInstruction* output = reduce->parent()->FusionInstruction(); @@ -725,19 +727,18 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, extra_output_gens[i].first(index)); - b_.CreateStore(extra_output_ir_value, extra_output_address); + Store(extra_output_ir_value, extra_output_address); } return Status::OK(); } Status IrEmitterUnnested::EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // Number of elements processed by a single thread. constexpr int64 kTileSize = 16; @@ -798,8 +799,7 @@ Status IrEmitterUnnested::EmitReductionToScalar( // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize), // // // // and threads_per_block is a multiple of warpSize. 
- // reduce_kernel<<>>(); - // + // reduce_kernel // auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { const int num_reduces = reducers.size(); llvm::Type* element_ir_type = @@ -807,17 +807,17 @@ Status IrEmitterUnnested::EmitReductionToScalar( std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; - x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); + x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); // Emit an inner for-loop that reduces the elements in the tile. auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { @@ -829,15 +829,14 @@ Status IrEmitterUnnested::EmitReductionToScalar( // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &b_); - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)), - tile_element_loop->GetIndVarValue()); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)), + tile_element_loop->GetIndVarValue()); // Unless we know the tile is entirely in bounds, we have to emit a // x-in-bounds check before reading from the input. 
if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", - &b_); + ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. @@ -846,11 +845,11 @@ Status IrEmitterUnnested::EmitReductionToScalar( IrArray::Index input_index( /*linear=*/x, input_shape, &b_); - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -861,14 +860,14 @@ Status IrEmitterUnnested::EmitReductionToScalar( // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's // immediately beyond the tile. - llvm::Value* x_end = b_.CreateNSWAdd( - index_typed_constant(kTileSize), - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize))); + llvm::Value* x_end = + NSWAdd(index_typed_constant(kTileSize), + NSWMul(x_in_tiles, index_typed_constant(kTileSize))); // The tile is entirely in bound if all_threads_in_bounds or // x_end <= num_elems. 
llvm::Value* tile_in_bounds = - b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)), - b_.getInt1(all_threads_in_bounds)); + Or(ICmpULE(x_end, index_typed_constant(num_elems)), + b_.getInt1(all_threads_in_bounds)); llvm_ir::LlvmIfData if_tile_in_bounds_data = llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); @@ -889,20 +888,18 @@ Status IrEmitterUnnested::EmitReductionToScalar( for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* result_from_other_lane = - b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); + Alloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = b_.CreateLoad( - b_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); + llvm::Value* partial_reduction_result = + Load(BitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) << "Requires block size a multiple of the warp size, otherwise we " "will read undefined elements."; - b_.CreateStore( - EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - b_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + Store(EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -917,10 +914,9 @@ Status IrEmitterUnnested::EmitReductionToScalar( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. 
llvm::Value* lane_id = - b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); + URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", - &b_); + ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { @@ -952,12 +948,11 @@ Status IrEmitterUnnested::EmitReductionToScalar( Status IrEmitterUnnested::EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // Divide the input matrix into tiles of size KxL. For example, when the // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like @@ -1040,12 +1035,12 @@ Status IrEmitterUnnested::EmitColumnReduction( for (int i = 0; i != num_reduces; ++i) { for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + - llvm::Twine(i * kTileWidth + x_offset)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." 
+ + llvm::Twine(i * kTileWidth + x_offset)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1056,8 +1051,8 @@ Status IrEmitterUnnested::EmitColumnReduction( llvm::Value* y_in_tiles = tile_index[0]; llvm::Value* x_in_tiles = tile_index[1]; - y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty); - x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); + y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty); + x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); auto emit_tile_element_loop = [=](bool tile_in_y_bounds, bool tile_in_x_bounds) -> Status { @@ -1069,34 +1064,32 @@ Status IrEmitterUnnested::EmitColumnReduction( // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &b_); - llvm::Value* y = b_.CreateNSWAdd( - b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)), - tile_element_loop->GetIndVarValue()); + llvm::Value* y = + NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)), + tile_element_loop->GetIndVarValue()); // Unless we know that y is in bounds, we have to emit a check before // reading from the input. if (!tile_in_y_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds", - &b_); + ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. 
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); // Unless we know that x is in bounds, we have to emit a check before // reading from the input. if (!tile_in_x_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds", - &b_); + ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); // {y,x} is an index to input_matrix_shape [height,width]. We need to // convert that to an index to input_shape (the shape of the operand of // "reduce"). This conversion is composed of a transposition from @@ -1123,7 +1116,7 @@ Status IrEmitterUnnested::EmitColumnReduction( for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i * kTileWidth + x_offset], @@ -1138,20 +1131,20 @@ Status IrEmitterUnnested::EmitColumnReduction( // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location // that's immediately beyond the tile. 
- llvm::Value* y_end = b_.CreateNSWAdd( - index_typed_constant(kTileHeight), - b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight))); + llvm::Value* y_end = + NSWAdd(index_typed_constant(kTileHeight), + NSWMul(y_in_tiles, index_typed_constant(kTileHeight))); // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location // that's immediately beyond the tile. - llvm::Value* x_end = b_.CreateNSWAdd( - index_typed_constant(kTileWidth), - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth))); + llvm::Value* x_end = + NSWAdd(index_typed_constant(kTileWidth), + NSWMul(x_in_tiles, index_typed_constant(kTileWidth))); llvm::Value* tile_in_y_bounds = - b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)), - b_.getInt1(height % kTileHeight == 0)); + Or(ICmpULE(y_end, index_typed_constant(height)), + b_.getInt1(height % kTileHeight == 0)); llvm::Value* tile_in_x_bounds = - b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)), - b_.getInt1(width % kTileWidth == 0)); + Or(ICmpULE(x_end, index_typed_constant(width)), + b_.getInt1(width % kTileWidth == 0)); // The tile is in y bounds if "height" is a multiple of kTileHeight or // y_end <= height. llvm_ir::LlvmIfData if_tile_in_y_bounds_data = @@ -1185,9 +1178,9 @@ Status IrEmitterUnnested::EmitColumnReduction( reduce->IsFused() ? 
reduce->parent()->FusionInstruction() : reduce; for (int i = 0; i != num_reduces; ++i) { for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); llvm::Value* output_address = GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( @@ -1243,12 +1236,11 @@ static std::pair ComputeTilingSchemeForReduction( Status IrEmitterUnnested::EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // A naive algorithm is: // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. @@ -1376,11 +1368,11 @@ Status IrEmitterUnnested::EmitRowReduction( std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." 
+ llvm::Twine(i)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1389,22 +1381,20 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; - x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty); + x_tile = ZExtOrTrunc(x_tile, index_ty); llvm::Value* warp_id = - b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); + UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); llvm::Value* lane_id = - b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id"); + URem(x_tile, index_typed_constant(kWarpSize), "lane_id"); // The x-location of the last element in this z-x-tile. // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); - llvm::Value* last_x = b_.CreateNSWAdd( + llvm::Value* last_x = NSWAdd( lane_id, - b_.CreateNSWMul( - index_typed_constant(kWarpSize), - b_.CreateNSWAdd( - index_typed_constant(x_tile_size - 1), - b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size))))); + NSWMul(index_typed_constant(kWarpSize), + NSWAdd(index_typed_constant(x_tile_size - 1), + NSWMul(warp_id, index_typed_constant(x_tile_size))))); KernelSupportLibrary ksl( &b_, @@ -1416,9 +1406,8 @@ Status IrEmitterUnnested::EmitRowReduction( auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, int64 x_tile_loop_bound) -> Status { auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { - llvm::Value* z = b_.CreateNSWAdd( - z_indvar, - b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile)); + llvm::Value* z = + NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile)); TF_RETURN_IF_ERROR(ksl.For( "x_tile", /*start=*/index_typed_constant(0), @@ -1426,22 +1415,20 @@ Status 
IrEmitterUnnested::EmitRowReduction( /*step=*/1, [&](llvm::Value* x_indvar) -> Status { // x = lane_id + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); - llvm::Value* x = b_.CreateNSWAdd( + llvm::Value* x = NSWAdd( lane_id, - b_.CreateNSWMul( - index_typed_constant(kWarpSize), - b_.CreateNSWAdd( - x_indvar, b_.CreateNSWMul( - warp_id, llvm::ConstantInt::get( - index_ty, x_tile_size))))); + NSWMul(index_typed_constant(kWarpSize), + NSWAdd(x_indvar, + NSWMul(warp_id, llvm::ConstantInt::get( + index_ty, x_tile_size))))); // Unless we know the x-tile is entirely in bounds, we have to // emit a x-in-bounds check before reading from the input. if (!x_tile_in_bounds) { llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(width)), - "x_in_bounds", &b_); + ICmpULT(x, index_typed_constant(width)), "x_in_bounds", + &b_); // Points b_ to the then-block. llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, &b_); @@ -1449,7 +1436,7 @@ Status IrEmitterUnnested::EmitRowReduction( // Emit code that reads the input element and accumulates it // to the partial reduction result. - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); { // {z,y,x} is an index to input_3d_tensor_shape // [depth,height,width]. 
We need to convert that to an index @@ -1480,7 +1467,7 @@ Status IrEmitterUnnested::EmitRowReduction( for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -1500,8 +1487,8 @@ Status IrEmitterUnnested::EmitRowReduction( }; llvm::Value* tile_in_bounds = - b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), - b_.CreateICmpULT(last_x, index_typed_constant(width))); + Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), + ICmpULT(last_x, index_typed_constant(width))); TF_RETURN_IF_ERROR( ksl.If(tile_in_bounds, @@ -1529,20 +1516,18 @@ Status IrEmitterUnnested::EmitRowReduction( for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* result_from_other_lane = - b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); + Alloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = b_.CreateLoad( - b_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); + llvm::Value* partial_reduction_result = + Load(BitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) << "Requires block size a multiple of the warp size, otherwise we " "will read undefined elements."; - b_.CreateStore( - EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - b_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + Store(EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + BitCast(result_from_other_lane, 
shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -1557,8 +1542,7 @@ Status IrEmitterUnnested::EmitRowReduction( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", - &b_); + ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { llvm::Value* output_address = @@ -1604,13 +1588,12 @@ Status IrEmitterUnnested::EmitRowReduction( // elementwise. Status IrEmitterUnnested::EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span dimensions_to_reduce, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens) { // This emission requires "reduce" to have an input layout. It is either set // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for @@ -1705,7 +1688,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { } auto input = reduce->operand(0); auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions_to_reduce(reduce->dimensions()); + absl::Span dimensions_to_reduce(reduce->dimensions()); HloComputation* reducer = reduce->to_apply(); // HandleReduce specializes reduction from a multi-dimensional array to a 1D // array. 
The specialized version requires an initializer thunk that @@ -1718,7 +1701,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { thunks.push_back( BuildKernelThunk(reduce, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), reduce)); + absl::make_unique(std::move(thunks), reduce)); return EmitReductionToVector( reduce, input->shape(), {[&](const IrArray::Index& index) { @@ -1738,7 +1721,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { bool all_tuple_elements_have_buffer = - c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { + absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { return ir_emitter_context_->buffer_assignment() .GetUniqueTopLevelSlice(tuple_element) .ok(); @@ -1760,7 +1743,7 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { for (const HloInstruction* tuple_element : tuple->operands()) { tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); } - thunk_sequence_->emplace_back(MakeUnique( + thunk_sequence_->emplace_back(absl::make_unique( tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } @@ -1792,8 +1775,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( thunks.push_back(std::move(initializer_thunk)); thunks.push_back(BuildKernelThunk(select_and_scatter, /*implements_whole_instruction=*/false)); - thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), select_and_scatter)); + thunk_sequence_->emplace_back(absl::make_unique( + std::move(thunks), select_and_scatter)); // TODO(b/31410564): Implement dilation rate for select-and-scatter. 
if (window_util::HasDilation(window)) { @@ -1842,7 +1825,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( b_.getInt1Ty(), "initialized_flag_address", &b_); - b_.CreateStore(b_.getInt1(false), initialized_flag_address); + Store(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, @@ -1863,15 +1846,15 @@ Status IrEmitterUnnested::HandleSelectAndScatter( IrArray::Index operand_index(index_type, source_index.size()); llvm::Value* in_bounds_condition = b_.getInt1(true); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = b_.CreateNSWMul( + llvm::Value* strided_index = NSWMul( source_index[i], index_typed_constant(window.dimensions(i).stride())); - operand_index[i] = b_.CreateNSWSub( - b_.CreateNSWAdd(strided_index, window_index[i]), - index_typed_constant(window.dimensions(i).padding_low())); - llvm::Value* index_condition = b_.CreateICmpULT( + operand_index[i] = + NSWSub(NSWAdd(strided_index, window_index[i]), + index_typed_constant(window.dimensions(i).padding_low())); + llvm::Value* index_condition = ICmpULT( operand_index[i], index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + in_bounds_condition = And(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); @@ -1881,7 +1864,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - b_.CreateLoad(initialized_flag_address), "initialized", &b_); + Load(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently 
visiting operand. @@ -1889,16 +1872,16 @@ Status IrEmitterUnnested::HandleSelectAndScatter( const auto save_operand_index = [&](const IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - b_.CreateStore(operand_index[i], selected_index_address_slot); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + Store(operand_index[i], selected_index_address_slot); } }; IrArray operand_array = GetIrArray(*operand, *select_and_scatter); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); - b_.CreateStore(operand_data, selected_value_address); + Store(operand_data, selected_value_address); save_operand_index(operand_index); - b_.CreateStore(b_.getInt1(true), initialized_flag_address); + Store(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to // potentially update the selected value and index with the currently @@ -1914,11 +1897,11 @@ Status IrEmitterUnnested::HandleSelectAndScatter( TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *select_and_scatter->select(), {selected_value_address, operand_address}, select_return_buffer)); - llvm::Value* result = b_.CreateLoad(select_return_buffer); + llvm::Value* result = Load(select_return_buffer); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. 
- llvm::Value* cond = b_.CreateICmpNE( + llvm::Value* cond = ICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType( PRED, ir_emitter_context_->llvm_module()), @@ -1927,7 +1910,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_); - b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); + Store(Load(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -1939,8 +1922,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( IrArray::Index selected_index(operand_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(Load(selected_index_address_slot)); } llvm::Value* source_value_address = GetIrArray(*source, *select_and_scatter) @@ -2018,7 +2001,7 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { thunks.push_back(std::move(rng_thunk)); thunks.push_back(std::move(increment_seed_thunk)); thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), rng)); + absl::make_unique(std::move(thunks), rng)); return Status::OK(); } @@ -2043,7 +2026,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { auto values_destination = GetAllocationSlice(*sort, values_shape_index); if (keys_destination != GetAllocationSlice(*keys)) { - thunks.push_back(MakeUnique( + thunks.push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*keys), /*destination_buffer=*/keys_destination, /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), nullptr)); @@ -2051,7 +2034,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* 
sort) { if (values != nullptr && values_destination != GetAllocationSlice(*values)) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. - thunks.push_back(MakeUnique( + thunks.push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*values), /*destination_buffer=*/values_destination, /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr)); @@ -2095,15 +2078,15 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index), - values != nullptr ? tensorflow::gtl::make_optional( + values != nullptr ? absl::make_optional( GetIrArray(*sort, *sort, values_shape_index)) - : tensorflow::gtl::nullopt, + : absl::nullopt, IrName(sort), xor_mask, &b_, &launch_dimensions)); } } thunk_sequence_->emplace_back( - MakeUnique(std::move(thunks), sort)); + absl::make_unique(std::move(thunks), sort)); return Status::OK(); } @@ -2130,7 +2113,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { if (crs->operand_count() == 1) { CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); - thunk_sequence_->push_back(MakeUnique( + thunk_sequence_->push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(0)), /*destination_buffer=*/GetAllocationSlice(*crs), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); @@ -2145,17 +2128,17 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment() .GetUniqueSlice(crs, {i}) .ValueOrDie()); - thunks.push_back(MakeUnique( + thunks.push_back(absl::make_unique( /*source_address=*/GetAllocationSlice(*crs->operand(i)), /*destination_buffer=*/tuple_element_buffers.back(), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr)); } // Output a tuple of the buffers above. 
- thunks.push_back(MakeUnique(tuple_element_buffers, - GetAllocationSlice(*crs), nullptr)); + thunks.push_back(absl::make_unique( + tuple_element_buffers, GetAllocationSlice(*crs), nullptr)); thunk_sequence_->push_back( - MakeUnique(std::move(thunks), crs)); + absl::make_unique(std::move(thunks), crs)); return Status::OK(); } @@ -2305,7 +2288,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( for (const auto& kv : hlo_slices) { buffers_needed.insert(kv.second.first.allocation()); } - tensorflow::gtl::optional temp_buffer; + absl::optional temp_buffer; for (const BufferAllocation& alloc : buffer_assn.Allocations()) { if (alloc.IsPreallocatedTempBuffer()) { if (!temp_buffer.has_value()) { @@ -2322,10 +2305,10 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( // We'll pass a pointer to each of the elements of `buffers` to our kernel, in // this order. std::vector non_constant_buffers; - c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), - [](const BufferAllocation* allocation) { - return !allocation->is_constant(); - }); + absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), + [](const BufferAllocation* allocation) { + return !allocation->is_constant(); + }); std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), [](const BufferAllocation* a, const BufferAllocation* b) { @@ -2364,8 +2347,8 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( *slice.allocation()))); CHECK_NE(loc, nullptr); } else { - loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), - {b_.getInt64(slice.offset())}); + loc = InBoundsGEP(kernel_args.at(slice.allocation()), + {b_.getInt64(slice.offset())}); } // If gte_index is nonempty, we have to dereference `loc` to get to the @@ -2373,8 +2356,8 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::Type* int8_double_pointer = llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0); for (int64 idx : gte_index) { - loc = b_.CreateBitCast(loc, 
int8_double_pointer); - loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)})); + loc = BitCast(loc, int8_double_pointer); + loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)})); } bindings_.BindHloToIrValue(*instr, loc, index); @@ -2389,7 +2372,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return MakeUnique( + return absl::make_unique( non_constant_buffers, llvm_ir::AsString(kernel->getName()), implements_whole_instruction ? inst : nullptr, unroll_factor); } @@ -2398,7 +2381,7 @@ std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); - return MakeUnique( + return absl::make_unique( /*source_address=*/operand->literal().untyped_data(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -2410,7 +2393,7 @@ std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( std::unique_ptr IrEmitterUnnested::BuildDeviceToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); - return MakeUnique( + return absl::make_unique( /*source_address=*/GetAllocationSlice(*operand), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -2430,7 +2413,7 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( .GetUniqueSlice(inst, index) .ConsumeValueOrDie(); }); - return MakeUnique(slices, inst); + return absl::make_unique(slices, inst); } std::unique_ptr IrEmitterUnnested::BuildOutfeedThunk( @@ -2447,7 +2430,7 @@ std::unique_ptr IrEmitterUnnested::BuildOutfeedThunk( *slice = status_or_slice.ConsumeValueOrDie(); } }); - return MakeUnique(std::move(slices), inst); + return absl::make_unique(std::move(slices), inst); } namespace { @@ -2470,7 +2453,7 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( if (inst->opcode() == HloOpcode::kDot) { const HloInstruction* lhs = inst->operand(0); const 
HloInstruction* rhs = inst->operand(1); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. @@ -2512,7 +2495,7 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* rhs = inst->operand(rhs_parameter->parameter_number()); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. @@ -2529,11 +2512,12 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( std::unique_ptr IrEmitterUnnested::BuildFftThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); - return MakeUnique(inst->fft_type(), inst->fft_length(), - /*input_buffer=*/GetAllocationSlice(*operand), - /*output_buffer=*/GetAllocationSlice(*inst), - /*input_shape=*/operand->shape(), - /*output_shape=*/inst->shape(), inst); + return absl::make_unique( + inst->fft_type(), inst->fft_length(), + /*input_buffer=*/GetAllocationSlice(*operand), + /*output_buffer=*/GetAllocationSlice(*inst), + /*input_shape=*/operand->shape(), + /*output_shape=*/inst->shape(), inst); } StatusOr> IrEmitterUnnested::BuildInitializerThunk( @@ -2580,11 +2564,11 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // Are all the bytes of this scalar equal to 0? If so, we can create a // MemzeroThunk. 
- ArraySlice literal_bytes( + absl::Span literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); - if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return { - MakeUnique(GetAllocationSlice(*hlo, index), nullptr)}; + if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { + return {absl::make_unique(GetAllocationSlice(*hlo, index), + nullptr)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -2601,7 +2585,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16)); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); - return {MakeUnique( + return {absl::make_unique( pattern32, GetAllocationSlice(*hlo, index), nullptr)}; } @@ -2612,7 +2596,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( literal_bytes.size() - 4) == 0) { uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); - return {MakeUnique( + return {absl::make_unique( word, GetAllocationSlice(*hlo, index), nullptr)}; } } @@ -2670,8 +2654,7 @@ Status CheckHloBuffersShareAllocation( if (slice_a != slice_b) { return InternalError( "instruction %s %s does not share allocation with instruction %s %s", - a->ToString().c_str(), slice_a.ToString().c_str(), - b->ToString().c_str(), slice_b.ToString().c_str()); + a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString()); } return Status::OK(); } @@ -2764,7 +2747,7 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*condition->root_instruction()), // cond result ir_emitter_condition.ConsumeThunkSequence(), ir_emitter_body.ConsumeThunkSequence(), hlo); @@ -2782,8 +2765,8 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return MakeUnique(loop_limit, - 
ir_emitter_body.ConsumeThunkSequence(), hlo); + return absl::make_unique( + loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo); } std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( @@ -2803,7 +2786,7 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( ir_emitter_context_); TF_CHECK_OK(false_computation->Accept(&ir_emitter_false)); - return MakeUnique( + return absl::make_unique( GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo->operand(1)), GetAllocationSlice(*hlo->operand(2)), @@ -2891,7 +2874,7 @@ int IrEmitterUnnested::ConstructIrArrayForInputs( int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( const HloInstruction& hlo, const std::vector& output_arrays, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* output_reduced_shapes, std::vector* output_in_reduced_shape_arrays) { int64 num_outputs = 1; @@ -2918,7 +2901,7 @@ int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( const HloInstruction& hlo, const std::vector& param_arrays, const std::vector& param_buffers, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* param_reduced_shapes, std::vector* param_in_reduced_shape_arrays) { int64 num_params = hlo.operands().size(); @@ -3059,8 +3042,8 @@ void EmitTiledElementalCodeWithBoundsCheck( // TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient // to launch fewer blocks so each transposes many tiles. LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( - HloInstruction* hlo, tensorflow::gtl::ArraySlice reduced_output_dims, - tensorflow::gtl::ArraySlice tiled_param_ids) { + HloInstruction* hlo, absl::Span reduced_output_dims, + absl::Span tiled_param_ids) { // Parameters for the tiling algorithm. 
constexpr int64 kTileSize = 32; constexpr int64 kNumRows = 4; @@ -3105,7 +3088,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( CeilOfRatio(output_dims_in_tiles[i], kTileSize); } const int64 num_tiles = - c_accumulate(output_dims_in_tiles, 1, std::multiplies()); + absl::c_accumulate(output_dims_in_tiles, 1, std::multiplies()); LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile); llvm::Type* index_ty = @@ -3151,9 +3134,8 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( const IrArray::Index output_tile_origin = [&] { IrArray::Index index = output_tile_index; for (int i = 1; i < 3; ++i) { - index[i] = - b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize), - "tile_origin." + std::to_string(i)); + index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize), + "tile_origin." + std::to_string(i)); } return index; }(); @@ -3166,12 +3148,12 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( std::vector output_tile_bounds(3); for (int i = 1; i < 3; ++i) { // Only last row or column may not have full size. - output_tile_bounds[i] = b_.CreateSelect( - b_.CreateICmpEQ(output_tile_index[i], - index_typed_constant(output_dims_in_tiles[i] - 1)), - index_typed_constant(reduced_output_dims[i] - - (output_dims_in_tiles[i] - 1) * kTileSize), - index_typed_constant(kTileSize), "kTileSize"); + output_tile_bounds[i] = + Select(ICmpEQ(output_tile_index[i], + index_typed_constant(output_dims_in_tiles[i] - 1)), + index_typed_constant(reduced_output_dims[i] - + (output_dims_in_tiles[i] - 1) * kTileSize), + index_typed_constant(kTileSize), "kTileSize"); } KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); @@ -3189,7 +3171,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( // Adds `addend` to the given `dim` of `index`. 
auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = b_.CreateAdd(index[dim], addend); + index[dim] = Add(index[dim], addend); return index; }; const IrArray::Index input_index = @@ -3205,10 +3187,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( llvm::Value* shmem_buffer = param_shmem_buffers[id]; // TODO(jlebar): Add AA metadata to this store. Tile buffers are // global variables, so LLVM can't infer much about it. - b_.CreateStore( - input_in_logical_shape.EmitReadArrayElement(index, &b_, - "input_element"), - b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); + Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, + "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); } }); @@ -3229,9 +3210,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( output_index, "output", output_tile_bounds[2], output_tile_bounds[1], [&](const IrArray::Index& index, llvm::Value* y_loc) { // TODO(jlebar): Add AA metadata to this load. 
- llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad( - b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), - "output_element"); + llvm::Instruction* load_from_shmem_buffer = + Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), + "output_element"); output_in_reduced_shape_arrays[0].EmitWriteArrayElement( index, load_from_shmem_buffer, &b_); }); @@ -3259,7 +3240,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( output_in_reduced_shape_arrays.size()); for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) { output_in_reduced_shape_arrays[i].EmitWriteArrayElement( - index, b_.CreateExtractValue(output_value, i), &b_); + index, ExtractValue(output_value, i), &b_); } } else { output_in_reduced_shape_arrays[0].EmitWriteArrayElement( @@ -3308,7 +3289,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { if (!reduced_dims_021.has_value()) { reduced_dims_021 = curr_reduced_dims_021; } - if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) { + if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) { // There is more than one possible transpose. Instead of picking one // transpose, we simply give up here. return false; @@ -3341,7 +3322,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { // if there's a Right Choice. // // This is only sound if tiled transposes are the only place where we use - // shared memory in fusions. If in the future other fusile ops use shared + // shared memory in fusions. If in the future other fusible ops use shared // memory, we'll have to adjust this heuristic. 
constexpr int kMinBlocksPerCore = 3; constexpr int64 kShmemPerCore = 48 * 1024; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 525441990795e160ba0e8facb910d5cc9796c4bb..084462330ed20108a9ec850b4cbc588afe77cc01 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -105,13 +105,12 @@ class IrEmitterUnnested : public IrEmitter { // This kernel takes as arguments pointers to the given buffer allocations. llvm::Function* BuildKernelPrototype( const HloInstruction& inst, - tensorflow::gtl::ArraySlice args); + absl::Span args); // Helper for writing extra outputs from inside a reduce kernel. Status EmitExtraOutputsForReduce( const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span> extra_output_gens); // EmitColumnReduction and EmitRowReduction emit code for column and row @@ -127,12 +126,11 @@ class IrEmitterUnnested : public IrEmitter { Status EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Emits code that reduces a 3D tensor of shape [depth x height x width] to a @@ -143,23 +141,21 @@ class IrEmitterUnnested : public IrEmitter { Status EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice 
reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Emits code that reduces a tensor of arbitrary rank to a scalar. Status EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Figures out whether `reduce` is a row or column reduction, and which @@ -180,13 +176,12 @@ class IrEmitterUnnested : public IrEmitter { // Prerequisite: `IsReductionToVector(*reduce)` Status EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - tensorflow::gtl::ArraySlice input_gens, - tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - tensorflow::gtl::ArraySlice reducers, - tensorflow::gtl::ArraySlice reduce_output_shapes, - tensorflow::gtl::ArraySlice< - std::pair> + absl::Span input_gens, + absl::Span init_value_gens, + absl::Span dimensions_to_reduce, + absl::Span reducers, + absl::Span reduce_output_shapes, + absl::Span> extra_output_gens); // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel @@ -195,10 +190,9 @@ class IrEmitterUnnested : public IrEmitter { // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and // returns the launch dimensions for the kernel. This is a helper to support // the implementation of CheckAndEmitHloWithTile021. 
- LaunchDimensions EmitHlo021Tile( - HloInstruction* hlo, - tensorflow::gtl::ArraySlice reduced_output_dims, - tensorflow::gtl::ArraySlice tiled_param_ids); + LaunchDimensions EmitHlo021Tile(HloInstruction* hlo, + absl::Span reduced_output_dims, + absl::Span tiled_param_ids); // Generates the IrArray for each output of hlo and returns the number of // outputs. int ConstructIrArrayForOutputs(const HloInstruction& hlo, @@ -214,7 +208,7 @@ class IrEmitterUnnested : public IrEmitter { int ConstructOutputReducedShapeAndCastOutputIrArrayToShape( const HloInstruction& hlo, const std::vector& output_arrays, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* output_reduced_shapes, std::vector* output_in_reduced_shape_arrays); // For each input of the `hlo` instruction, checks its value in @@ -226,7 +220,7 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& hlo, const std::vector& param_arrays, const std::vector& param_buffers, - tensorflow::gtl::ArraySlice reduced_output_dims, + absl::Span reduced_output_dims, std::vector* param_reduced_shapes, std::vector* param_in_reduced_shape_arrays); diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index e76823ad103dfa5ba61a0d3ba81b2c028dfeb33e..e09b8fbd3ba275e14accbf88c21f3d10f34198d9 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -15,22 +15,22 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { -KernelThunk::KernelThunk( - tensorflow::gtl::ArraySlice args, - const string& kernel_name, const HloInstruction* hlo_instruction, - int unroll_factor) +KernelThunk::KernelThunk(absl::Span args, + const string& kernel_name, + const HloInstruction* hlo_instruction, + int unroll_factor) : Thunk(Kind::kKernel, hlo_instruction), args_(args.begin(), args.end()), kernel_name_(kernel_name), @@ -41,11 +41,7 @@ Status KernelThunk::Initialize(const GpuExecutable& executable, tensorflow::mutex_lock lock(mutex_); if (!loader_spec_) { loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); - tensorflow::StringPiece ptx = executable.ptx(); - // Convert tensorflow::StringPiece to se::port::StringPiece because - // StreamExecutor uses the latter. 
- loader_spec_->AddCudaPtxInMemory( - se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); + loader_spec_->AddCudaPtxInMemory(executable.ptx(), kernel_name_); if (!executable.cubin().empty()) { loader_spec_->AddCudaCubinInMemory( @@ -63,7 +59,7 @@ Status KernelThunk::Initialize(const GpuExecutable& executable, if (kernel_cache_.end() == it) { it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; if (!executor->GetKernel(*loader_spec_, &it->second)) { - return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + return InternalError("Unable to load kernel %s", kernel_name_); } } @@ -95,7 +91,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, VLOG(3) << "Launching " << kernel->name(); // Launch the kernel with potentially multiple blocks and threads. static constexpr int kKernelArgsLimit = 1024; - auto kernel_args = MakeUnique>(); + auto kernel_args = absl::make_unique>(); for (const BufferAllocation* arg : args_) { const auto& buf = buffer_allocations.GetDeviceAddress(arg->index()); kernel_args->add_device_memory_argument(buf); @@ -107,7 +103,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, stream, se::ThreadDim(launch_dimensions.threads_per_block()), se::BlockDim(launch_dimensions.block_count()), *kernel, *kernel_args)) { - return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); + return InternalError("Unable to launch kernel %s", kernel_name_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index d751de50ad6671b3bf88cd4de49a8feb448e13ba..f63db5c3696f8f3bbd5956724240b2b06b4f1b98 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -47,7 +47,7 @@ class KernelThunk : public Thunk { // Constructs a thunk for the given kernel. // // `hlo_instruction` is as in Thunk. Other arguments are as the class members. - KernelThunk(tensorflow::gtl::ArraySlice args, + KernelThunk(absl::Span args, const string& kernel_name, const HloInstruction* hlo_instruction, int unroll_factor); KernelThunk(const KernelThunk&) = delete; diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index eb93efc560efbb4c14065ec98b980a1ca78605c6..698d2d51cc81a6c87f6578f1f35cdb47cf6bb4f2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -34,6 +34,9 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@llvm//:amdgpu_code_gen", "@llvm//:analysis", "@llvm//:bit_reader", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc index 12a8a59488bfdd6ce55f762926cd63ba56bf9d7f..85bc58cb445627695a46171db64cd8a1f10e0fc8 100644 
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -86,10 +86,11 @@ void IrDumpingPassManager::run(llvm::Module &module) { const llvm::PassInfo *PI = llvm::PassRegistry::getPassRegistry()->getPassInfo(P->getPassID()); const string basename = ReplaceFilenameExtension( - tensorflow::io::Basename(input_filename_), - tensorflow::strings::Printf( + absl::string_view(tensorflow::io::Basename(input_filename_)), + absl::StrFormat( "pass-%02d.before.%s.ll", i, - (PI == nullptr ? "unknown" : PI->getPassArgument().data()))); + absl::string_view(PI == nullptr ? "unknown" + : PI->getPassArgument().data()))); llvm::legacy::PassManager::add( new DumpIrPass(tensorflow::io::JoinPath(output_dir_, basename))); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index ff4ae1f9ef2ad2fda4bb9100de93019c0b88fbd1..8751e3a9c2a4c8da46d3ecd8437629450d4a2ba2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -20,13 +20,15 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" @@ -54,10 +56,7 @@ limitations under the License. #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/Scalar.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/tracing.h" @@ -107,8 +106,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, << ", " << compute_capability.second << ") ." << "Defaulting to libdevice for compute_" << libdevice_version; } - return tensorflow::strings::StrCat("libdevice.compute_", libdevice_version, - ".10.bc"); + return absl::StrCat("libdevice.compute_", libdevice_version, ".10.bc"); } // Gets the GPU name as it's known to LLVM for a given compute capability. If @@ -138,15 +136,16 @@ static string GetSmName(std::pair compute_capability) { << "Defaulting to telling LLVM that we're compiling for sm_" << sm_version; } - return tensorflow::strings::StrCat("sm_", sm_version); + return absl::StrCat("sm_", sm_version); } // Convenience function for producing a name of a temporary compilation product // from the input filename. 
string MakeNameForTempProduct(const std::string& input_filename, - tensorflow::StringPiece extension) { - return ReplaceFilenameExtension( - tensorflow::io::Basename(llvm_ir::AsString(input_filename)), extension); + absl::string_view extension) { + return ReplaceFilenameExtension(absl::string_view(tensorflow::io::Basename( + llvm_ir::AsString(input_filename))), + extension); } // Initializes LLVM passes. Uses the PassRegistry mechanism. @@ -167,7 +166,7 @@ void InitializePasses(llvm::PassRegistry* pass_registry) { // Returns the TargetMachine, given a triple. std::unique_ptr GetTargetMachine( - llvm::Triple triple, tensorflow::StringPiece cpu_name, + llvm::Triple triple, absl::string_view cpu_name, const HloModuleConfig& hlo_module_config) { std::string error; const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error); @@ -205,7 +204,7 @@ std::unique_ptr GetTargetMachine( default: codegen_opt_level = CodeGenOpt::None; } - return WrapUnique(target->createTargetMachine( + return absl::WrapUnique(target->createTargetMachine( triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, Optional(RelocModel), Optional(CMModel), codegen_opt_level)); @@ -243,9 +242,9 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level, } // Emits the given module to a bit code file. -void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) { +void EmitBitcodeToFile(const Module& module, absl::string_view filename) { std::error_code error_code; - llvm::ToolOutputFile outfile(filename.ToString().c_str(), error_code, + llvm::ToolOutputFile outfile(string(filename).c_str(), error_code, llvm::sys::fs::F_None); if (error_code) { LOG(FATAL) << "opening bitcode file for writing: " << error_code.message(); @@ -266,8 +265,9 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { // get creative to add a suffix. 
string module_id(llvm_ir::AsString(module->getModuleIdentifier())); IrDumpingPassManager codegen_passes( - ReplaceFilenameExtension(tensorflow::io::Basename(module_id), - "-nvptx.dummy"), + ReplaceFilenameExtension( + absl::string_view(tensorflow::io::Basename(module_id)), + "-nvptx.dummy"), "", false); codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); @@ -332,8 +332,8 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, return !GV.hasName() || (GVS.count(GV.getName()) == 0); }); })) { - return tensorflow::errors::Internal(tensorflow::strings::StrCat( - "Error linking libdevice from ", libdevice_path)); + return tensorflow::errors::Internal( + absl::StrCat("Error linking libdevice from ", libdevice_path)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h index 54e0e140dea1c3a8b21ffde2950c4bc9b703b71c..9654175bfafbb2521743e7894188abe5b5a15217 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h @@ -20,11 +20,11 @@ limitations under the License. 
#include #include +#include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc index 9ef9bc3a50fc76f83f05e19163ab339f2da6ef3c..3b2c3591d95ee5a319c82336e9b500d14f88734f 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc @@ -17,13 +17,13 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Support/SourceMgr.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace { @@ -52,14 +52,13 @@ std::unique_ptr LoadIRModule(const string& filename, return module; } -string ReplaceFilenameExtension(tensorflow::StringPiece filename, - tensorflow::StringPiece new_extension) { +string ReplaceFilenameExtension(absl::string_view filename, + absl::string_view new_extension) { auto pos = filename.rfind('.'); - tensorflow::StringPiece stem = - pos == tensorflow::StringPiece::npos - ? filename - : tensorflow::StringPiece(filename.data(), pos); - return tensorflow::strings::StrCat(stem, ".", new_extension); + absl::string_view stem = pos == absl::string_view::npos + ? 
filename + : absl::string_view(filename.data(), pos); + return absl::StrCat(stem, ".", new_extension); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h index a6daeca95a6da66cb31b82805a6896f57cb80354..60f4926849cd3e8ad144f657f9feb3c3e1ea25e2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace llvm { class LLVMContext; @@ -41,8 +41,8 @@ std::unique_ptr LoadIRModule(const string& filename, // // For example: // ReplaceFilenameExtension("/foo/baz.txt", "cc") --> "/foo/baz.cc" -string ReplaceFilenameExtension(tensorflow::StringPiece filename, - tensorflow::StringPiece new_extension); +string ReplaceFilenameExtension(absl::string_view filename, + absl::string_view new_extension); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index c62bae0628f7b2fbfe822104fbe5f3528e0e09c3..c21f76f6eb1874bfa5a1d296c78ea0e3b9261eca 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -23,7 +23,9 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -48,7 +50,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, // If possible, we want to pick a reduce operand of the fusion root, // because it has the most constraints. for (const auto* inst : fused_expression_root->operands()) { - if (inst->opcode() == HloOpcode::kReduce) { + if (IsReductionToVector(*inst)) { return inst; } } @@ -63,7 +65,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, auto get_element_shape = [&](const HloInstruction* element_instr) { // Special handling of kReduce instructions -- the fusion // applies to the first operand. - if (element_instr->opcode() == HloOpcode::kReduce) { + if (IsReductionToVector(*element_instr)) { return element_instr->operand(0)->shape(); } return element_instr->shape(); @@ -85,65 +87,16 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, get_element_shape(element_instr_1), get_element_shape(element_instr_2)); } -namespace { -bool IsInputFusibleReduction(HloInstruction* instr) { - if (instr->IsMultiOutputFusion()) { - for (const HloInstruction* operand : - instr->fused_expression_root()->operands()) { - if (operand->opcode() == HloOpcode::kReduce) { - CHECK(instr->fusion_kind() == HloInstruction::FusionKind::kInput) - << " Reduce multi-output fusion " << instr->ToString() - << " must be an input fusion."; - return true; - } - } - return false; - } else if (instr->opcode() == HloOpcode::kFusion) { - // The loop emitter can handle to-vector reduce fusions. Such reduce - // fusions have the fusion kind kLoop rather than kInput. 
We do not fuse - // to-vector reduce fusions, because the resulting fusions may no longer be - // supported by loop emitter. - return IsReductionToVector(*instr->fused_expression_root()); - } else { - return IsReductionToVector(*instr); - } -} - -// The code emitted for reduction suffers from poor data locality if the layouts -// of input parameters differ. In such situtations it is beneficial not to fuse. -// We consider input params with maximum rank only. Params with smaller ranks -// will be broadcasted and have not been observed to cause data locality issues. -// TODO(b/111977086): Improve reduce emitters to remove this limitation. -bool ReduceFriendlyInputLayouts(HloInstruction* instr) { - std::vector params; - if (instr->opcode() == HloOpcode::kFusion) { - params = instr->fused_parameters(); - } else { - for (HloInstruction* operand : instr->operands()) { - params.push_back(operand); - } - } - int64 max_rank = 0; - const Layout* max_rank_layout; - for (HloInstruction* param : params) { - if (ShapeUtil::Rank(param->shape()) > max_rank) { - max_rank = ShapeUtil::Rank(param->shape()); - max_rank_layout = ¶m->shape().layout(); - } - } - return c_all_of(params, [&](HloInstruction* param) { - return (ShapeUtil::Rank(param->shape()) < max_rank) || - (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); - }); -} - -} // namespace - bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { - // We can fuse reduces and loop fusions. - return IsInputFusibleReduction(instr) || - (instr->opcode() == HloOpcode::kFusion && - instr->fusion_kind() == HloInstruction::FusionKind::kLoop); + // We can fuse reduces and loop fusions. Elementwise instructions can be fused + // with any other instruction. + // TODO(b/112957171): This should use the same isFusible logic as + // instruction_fusion. 
+ return instr->IsFusible() && + (IsInputFusibleReduction(*instr) || + (instr->opcode() == HloOpcode::kFusion && + instr->fusion_kind() == HloInstruction::FusionKind::kLoop) || + instr->IsElementwise()); } int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, @@ -177,11 +130,12 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, // merge into bigger loop fusions and input (reduce) fusions become fusions // with multiple reduce outputs. We could fuse reduce and loop fusions // together too (the result being an input fusion) if we find cases where this - // improves things. + // improves things. Also disable fusing standalone input-fusible reduces into + // loop fusions. CHECK(instr1->opcode() == HloOpcode::kFusion); if ((instr2->opcode() == HloOpcode::kFusion && instr1->fusion_kind() != instr2->fusion_kind()) || - (instr2->opcode() != HloOpcode::kFusion && + (IsReductionToVector(*instr2) && instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) { return false; } @@ -197,7 +151,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { tensorflow::gtl::FlatSet to_fuse; // Keep a list of the instructions to fuse after making all the fusion // decisions. We first aggressively add instructions to potential_fusion_list, - // then filter out instructions that will be no longer fusable because of + // then filter out instructions that will be no longer fusible because of // reachability change. This avoids recalculating reachability on a large set // of instructions. 
std::vector> @@ -212,8 +166,8 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { VLOG(3) << consumer->name() << " has no users."; continue; } - if (!IsInputFusibleReduction(consumer)) { - VLOG(3) << consumer->name() << " is not an input-fusable reduction."; + if (!IsInputFusibleReduction(*consumer)) { + VLOG(3) << consumer->name() << " is not an input-fusible reduction."; continue; } VLOG(3) << consumer->name() @@ -222,8 +176,8 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { auto consumer_operands = consumer->operands(); for (size_t i = 0; i < consumer_operands.size(); ++i) { HloInstruction* producer = consumer_operands[i]; - if (!producer->IsFusable()) { - VLOG(3) << producer->name() << " is not fusable."; + if (!producer->IsFusible()) { + VLOG(3) << producer->name() << " is not fusible."; continue; } const bool is_loop_fusion = @@ -237,7 +191,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { VLOG(3) << producer->name() << " has an incompatible shape."; continue; } - if (!ReduceFriendlyInputLayouts(producer)) { + if (!LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) { VLOG(3) << producer->name() << " has inputs with mixed layouts."; continue; } @@ -248,7 +202,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { } // Do not fuse a producer if the other operands of the fusion are // reachable from the producer, this would create a cycle. - if (c_any_of(consumer_operands, [&](HloInstruction* operand) { + if (absl::c_any_of(consumer_operands, [&](HloInstruction* operand) { return producer != operand && reachability()->IsReachable(producer, operand); })) { @@ -263,12 +217,12 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { } } - // Filter out pairs that will be no longer fusable because of reachability + // Filter out pairs that will be no longer fusible because of reachability // change. 
for (auto& fusion_pair : potential_fusion_list) { HloInstruction* producer = fusion_pair.first; HloInstruction* consumer = fusion_pair.second; - if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) { + if (!absl::c_any_of(consumer->operands(), [&](HloInstruction* operand) { return producer != operand && reachability()->IsReachable(producer, operand); })) { diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h index 67ca5d49eee8508e93284b134f8410eb3a89f9ce..f0b4d67ab8463a39161f71908746cad9e2a8670a 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -22,7 +22,7 @@ namespace xla { namespace gpu { // Multi-output fusion of sibling and producer-consumer instructions for the -// Jellyfish backend. +// GPU backend. class GpuMultiOutputFusion : public MultiOutputFusion { public: GpuMultiOutputFusion(); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 14f157a5e518a0ec82c664c123629d04bd385bbf..c822c94f1b102e02be4a13a35892a2c181702383 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -15,19 +15,19 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" - -namespace op = xla::testing::opcode_matchers; namespace xla { namespace gpu { +namespace op = xla::testing::opcode_matchers; + using MultiOutputFusionTest = HloTestBase; const char kModulePrefix[] = R"( @@ -47,7 +47,7 @@ const char kModulePrefix[] = R"( TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { // Fusion with reduce instruction root and a sibling reduce instruction // sharing the same input param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation { p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) @@ -74,7 +74,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[6400]{0} parameter(1) mul = f32[6400]{0} multiply(p1.1, p1.1) @@ -101,7 +101,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[10,10]{1,0} parameter(1) mul = 
f32[10,10]{1,0} multiply(p1.1, p1.1) @@ -130,7 +130,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) { // Two sibling fusions with reduce instruction roots sharing the same input // param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) @@ -165,7 +165,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { // Multi-output fusion with two reduce instructions root and a sibling reduce // instruction sharing the same input param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) { const.1 = f32[] constant(1) p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0) @@ -198,7 +198,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) { // Verify that if we already have a multi-output fusion that we prefer to pick // a reduce op from its operands for checking shape compatibility. 
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[10,10]{1,0} parameter(1) mul = f32[10,10]{1,0} multiply(p1.1, p1.1) @@ -228,7 +228,7 @@ TEST_F(MultiOutputFusionTest, } TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[6400]{0} parameter(0) ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) @@ -256,8 +256,136 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { op::Tuple(op::Multiply(), op::Divide())); } -TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { +TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) { + // Fusing a reduce into a loop fusion would require changing the fusion kind. + // That's not supported yet. auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(0) + reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add_computation + ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(1) + div = 
f32[6400]{0} divide(p0, const.2) + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Multiply(), op::Divide())); +} + +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + const.2 = f32[] constant(0) + ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1) + ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + 
const.2 = f32[] constant(0) + ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 + gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Multiply(), op::Exp(), op::Add())); +} + +TEST_F(MultiOutputFusionTest, + MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1) + ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + const.2 = f32[] constant(0) + ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} 
fusion(p0), kind=kLoop, calls=fused_computation_2 + gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 + gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(gte0, gte1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -277,7 +405,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_add { p0.1 = f32[2,2,2]{2,1,0} parameter(0) p1.1 = f32[2,2,2]{2,1,0} parameter(1) @@ -304,7 +432,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_select { p1.1 = f32[2,2,2]{2,1,0} parameter(1) c0 = f32[] constant(0) @@ -345,7 +473,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_element_wise { p0.1 = f32[2,2,2]{2,1,0} parameter(0) p1.1 = f32[2,2,2]{2,1,0} parameter(1) @@ -372,7 +500,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { 
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionFp16LoopFusionAndReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_select { p1.1 = f16[2,2,2]{2,1,0} parameter(1) c0 = f16[] constant(0) @@ -413,7 +541,7 @@ TEST_F(MultiOutputFusionTest, TEST_F(MultiOutputFusionTest, ProducerConsumerFusionReduceUnfriendlyLoopFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( mixed_input_layouts_computation { p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0) p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 1ab663fb2ae424a508084b0165c6812f724f28f3..f6325b33680629b7e3d3814b088582a5007de6dc 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -21,13 +21,15 @@ limitations under the License. #include // NOLINT(build/c++11): only using std::call_once, not mutex. #include +#include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -42,9 +44,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" @@ -84,7 +86,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -131,11 +132,16 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. +// +// It takes a compiler pointer, as passes may compile and execute HLOs on the +// fly for cuDNN verification or other purposes. 
Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { + DeviceMemoryAllocator* device_allocator, + Compiler* compiler) { { HloPassPipeline pipeline("optimization"); - pipeline.AddInvariantChecker(); + pipeline.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), @@ -151,7 +157,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { auto& pass = pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(); + pass.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls // where possible. Not every batchnorm op can be implemented as a call to @@ -198,8 +205,13 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // Convert convolutions into CustomCalls to cudnn, then canonicalize them // (PadInsertion). HloPassPipeline pipeline("conv_canonicalization"); - pipeline.AddInvariantChecker(); + pipeline.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass(); + // CudnnConvolutionRewriter may add instructions of the form + // reverse(constant), which it expects will be simplified by constant + // folding. 
+ pipeline.AddPass(); pipeline.AddPass(); if (IsVoltaOrLater(*stream_exec)) { pipeline.AddPass(); @@ -211,9 +223,22 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, } { - HloPassPipeline pipeline("layout_assignment"); + // Run layout assignment in a separate pipeline from + // "post-layout-assignment" because we want everything after layout + // assignment to have a layout-sensitive invariant-checker, but + // HloPassPipeline also runs its invariant checker before any passes are + // run, meaning, the pipeline that contains layout assignment cannot contain + // a layout-sensitive verifier! + HloPassPipeline pipeline("layout assignment"); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), stream_exec); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + { + HloPassPipeline pipeline("post-layout_assignment"); + pipeline.AddInvariantChecker(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. @@ -248,8 +273,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // the gte(customcall, 0) would probably already be into a fusion node. We // can't simplify across HloComputation boundaries, so in this case we // wouldn't be able to simplify away the new_tuple bits. - pipeline.AddPass(stream_exec, - device_allocator); + pipeline.AddPass( + stream_exec, device_allocator, compiler); // Clean up new_tuple described above. 
pipeline.AddPass(); @@ -259,17 +284,20 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassFix fusion("fusion"); - fusion.AddInvariantChecker(); + fusion.AddInvariantChecker(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); fusion.AddPass(); fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); + fusion.AddPass(); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); - reduce_pipeline.AddInvariantChecker(); + reduce_pipeline.AddInvariantChecker( + /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -295,7 +323,8 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. 
HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker(); + pipeline.AddInvariantChecker(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -345,9 +374,9 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { string vmaj_str, vmin_str, vdot_str; if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str, &vmin_str, &vdot_str) || - !tensorflow::strings::safe_strto64(vmaj_str, &vmaj) || - !tensorflow::strings::safe_strto64(vmin_str, &vmin) || - !tensorflow::strings::safe_strto64(vdot_str, &vdot)) { + !absl::SimpleAtoi(vmaj_str, &vmaj) || + !absl::SimpleAtoi(vmin_str, &vmin) || + !absl::SimpleAtoi(vdot_str, &vdot)) { LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path << " --version:\n" << out; @@ -459,7 +488,7 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, tensorflow::SubProcess ptxas_info_dumper; std::vector ptxas_args = { ptxas_path, ptx_path, "-o", cubin_path, - tensorflow::strings::StrCat("-arch=sm_", cc_major, cc_minor)}; + absl::StrCat("-arch=sm_", cc_major, cc_minor)}; if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } @@ -495,11 +524,15 @@ NVPTXCompiler::NVPTXCompiler() StatusOr> NVPTXCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { + // We dump the post-optimization HLO in RunBackend so no need to dump it here. 
+ VLOG(2) << "*** HLO Before Optimization"; + XLA_VLOG_LINES(2, module->ToString()); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); tracing::ScopedActivity activity("HLO Transforms", module->name(), /*is_expensive=*/true); TF_RETURN_IF_ERROR( - OptimizeHloModule(module.get(), stream_exec, device_allocator)); + OptimizeHloModule(module.get(), stream_exec, device_allocator, this)); return std::move(module); } @@ -533,8 +566,8 @@ StatusOr> NVPTXCompiler::RunBackend( // must also be used to determine the thunk launch schedule. std::unique_ptr stream_assignment = AssignStreams(*module); TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_schedule, - HloSchedule::Build(*module, *stream_assignment, pointer_size_)); + std::unique_ptr hlo_schedule, + GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -551,6 +584,7 @@ StatusOr> NVPTXCompiler::RunBackend( // include headers, so no need for us to print them ourselves. XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); XLA_VLOG_LINES(2, buffer_assignment->ToString()); + VLOG(2) << "*** HLO After Optimization"; XLA_VLOG_LINES(2, module->ToString()); const string xla_dump_optimized_hlo_proto_to = module->config().debug_options().xla_dump_optimized_hlo_proto_to(); @@ -662,7 +696,7 @@ StatusOr> NVPTXCompiler::RunBackend( // Write PTX to IR dump directory, if IR dumping was requested. 
if (!ir_dump_directory.empty()) { const string ptx_outfile = tensorflow::io::JoinPath( - ir_dump_directory, tensorflow::strings::StrCat(module->name(), ".ptx")); + ir_dump_directory, absl::StrCat(module->name(), ".ptx")); auto status = [&] { auto* env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory)); @@ -678,7 +712,7 @@ StatusOr> NVPTXCompiler::RunBackend( const std::vector cubin = CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); - auto thunk_schedule = MakeUnique( + auto thunk_schedule = absl::make_unique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), hlo_schedule->ThunkLaunchOrder()); VLOG(2) << "Printing the thunk schedule..."; @@ -692,7 +726,7 @@ StatusOr> NVPTXCompiler::RunBackend( cost_analysis.set_bytes_per_second( stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); - profile_index_map = MakeUnique(*module); + profile_index_map = absl::make_unique(*module); profile_printer = CreateHloProfilePrinterData(*profile_index_map, cost_analysis); } @@ -801,7 +835,7 @@ se::Platform::Id NVPTXCompiler::PlatformId() const { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::cuda::kCudaPlatformId, - []() { return xla::MakeUnique(); }); + []() { return absl::make_unique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index d4d2909f1b2dc57c3ae0f9d67067e533574369dd..8e97774750344bfc141daa7d752300762c708613 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -20,13 +20,13 @@ limitations under the License. 
#include #include +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc index 4aaf0c9e142106a0e74f319d71dad4c4c96d3f08..2fa170964e974a6535307d7a21eb3e7760d02536 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index b99d998c4d7df514c024b1f8d643d08c72059d0e..e0f3e84a4cb25792cf10d38fc529f3e638acf8e4 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -96,7 +96,7 @@ Status OutfeedThunk::ExecuteOnStream( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } VLOG(2) << "Outfeeding from GPU complete"; diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc 
b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc index 79f7d31816baf0b95b967771b956a9c06ac81e91..fa84d7722351b68770b876e3880b472eec3233d7 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -23,7 +23,6 @@ limitations under the License. namespace xla { namespace gpu { -using tensorflow::gtl::ArraySlice; // We want the input/output feature counts of an f16 conv to be factors of 8, // because without this cudnn can't use tensor cores on the conv. @@ -42,7 +41,7 @@ static constexpr double kMaxBytesTouchedIncrease = 1.2; // Pads the given dimensions in the given shape up to a multiple of // kDesiredNumFeaturesFactor. -static Shape PadShape(Shape s, ArraySlice dims) { +static Shape PadShape(Shape s, absl::Span dims) { for (int64 dim : dims) { int64 dim_to_pad_size = s.dimensions(dim); int64 new_dim_to_pad_size = diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h index 192359f026bfb2f1d5436713e4a30725fa0ad6ba..11dc56a64fda74cab12024e5f2c6fa2f63c9167d 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h @@ -32,9 +32,7 @@ namespace gpu { // TODO(jlebar): Also pad dots. 
class PadForTensorCores : public HloPassInterface { public: - tensorflow::StringPiece name() const override { - return "pad for tensor cores"; - } + absl::string_view name() const override { return "pad for tensor cores"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc index 99e7580b826fc5cd6d98a037a5eb064552952e18..5c92b0dcb873b873074704dca8f27d4067b070df 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc @@ -29,7 +29,7 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -using PadForTensorCoresTest = HloVerifiedTestBase; +class PadForTensorCoresTest : public HloVerifiedTestBase {}; TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { ParseAndVerifyModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index b22040eee167e784bed58dbc0d0ad2ae042037f3..9d85d746d84908eaa8d720bc3cccc475d81710f3 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -69,7 +70,7 @@ HloInstruction* MaybePaddedAndSlicedInput( PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + absl::make_unique(LiteralUtil::Zero(element_type)))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -126,7 +127,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, PrimitiveType element_type = kernel->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + absl::make_unique(LiteralUtil::Zero(element_type)))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -165,9 +166,9 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { Shape old_conv_shape = conv->shape().tuple_shapes(0); VLOG(1) << "Canonicalizing forward conv"; - auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel, - new_conv_window, - conv->convolution_dimension_numbers()); + auto new_conv = CreateCudnnConvForward( + old_conv_shape, new_input, new_kernel, new_conv_window, + conv->convolution_dimension_numbers(), conv->feature_group_count()); VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " << new_conv->ToString(); TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); @@ -236,7 +237,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique( 
+ HloInstruction::CreateConstant(absl::make_unique( LiteralUtil::Zero(input->shape().element_type())))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); @@ -246,7 +247,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter( backward_conv_shape, padded_input, output, new_backward_conv_window, - backward_conv_dnums); + backward_conv_dnums, backward_conv->feature_group_count()); VLOG(1) << "Canonicalizing backward filter conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " @@ -311,7 +312,7 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput( new_backward_conv_shape, output, filter, new_backward_conv_window, - backward_conv_dnums); + backward_conv_dnums, backward_conv->feature_group_count()); // The CustomCall created above returns a tuple (conv_result, scratch_memory). // Extract out the two elements. diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index 67e51509e4c717951c83c7e41943af1de762dee0..a622e894ed9c0d1534262e6b72a5f4ea7b7821ad 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -26,7 +26,7 @@ namespace gpu { // padding, so that they can be lowered to cuDNN convolution. 
class PadInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "pad insertion"; } + absl::string_view name() const override { return "pad insertion"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 3838fee674566196e10ddd98462c1a1aa7835e1a..8154d75d23a6d49153ccb6824402aff73f365617 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -40,7 +40,7 @@ ParallelLoopEmitter::ParallelLoopEmitter( ParallelLoopEmitter::ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, + absl::Span target_arrays, const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, int unroll_factor) : LoopEmitter(target_element_generator, target_arrays, b), @@ -57,8 +57,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( unroll_factor_(unroll_factor) {} std::vector -ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, + llvm::Type* index_type) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index b82a23419df08cafdc69b6d2f14528484b95dc73..f32ea1ce4c4192f39851a6441c46663df3063724 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -47,18 +47,17 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { // // This is used in multi-output fusion. 
target_element_generator should // produce a struct with N elements, one for each of target_arrays. - ParallelLoopEmitter( - const llvm_ir::ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, - const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b, - int unroll_factor = 1); + ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, + absl::Span target_arrays, + const LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* b, int unroll_factor = 1); ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type) override; private: // The thread and block dimension to parallelize the loop on. diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index d3fd0544fb68809125e9b9f7a5e5b7eff8c6ef43..cf9f102d31305da15dabaf6247f23c5ca9a9e054 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -18,15 +18,15 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -34,9 +34,8 @@ namespace gpu { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims) { - out << tensorflow::strings::Printf("[block: %lld, thread: %lld]", - launch_dims.block_count(), - launch_dims.threads_per_block()); + out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(), + launch_dims.threads_per_block()); return out; } @@ -91,9 +90,9 @@ LaunchDimensions CalculateLaunchDimensions( } int64 block_count = CeilOfRatio(num_elements, threads_per_block); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << absl::StrFormat( "Initialized the block count to ceil(# of elements / threads per " - "block) = ceil(%lld/%lld) = %lld", + "block) = ceil(%d/%d) = %d", num_elements, threads_per_block, block_count); return LaunchDimensions(block_count, threads_per_block); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 0806dd51614f4d2da12f3fbbc9fb98df5273d5c8..5b6cf2c04d05378a363232e33a6df6432cd6848e 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" @@ -119,7 +119,7 @@ int ComputeStreamToAssign( } // namespace std::unique_ptr AssignStreams(const HloModule& module) { - auto stream_assignment = MakeUnique(); + auto stream_assignment = absl::make_unique(); const HloComputation& computation = *module.entry_computation(); std::unique_ptr reachability = computation.ComputeReachability(); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 6f4bb0580e8dfc1dce1cca0a60cc3dd9ea600fb3..091aca23e54bf0585b91e7a05c0837d8a0a2b764 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -15,13 +15,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace gpu { @@ -33,7 +34,7 @@ class StreamAssignmentTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", config); + return absl::make_unique("test_module", config); } // Pre-canned shapes. @@ -97,7 +98,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( - i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index 05b305ea4cdfdbaeb42544b626a6b9990bb42f57..08ff52211af163fec39646ca6bf14da9d1b815e4 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -16,6 +16,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace gpu { @@ -53,8 +55,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, input_layout.push_back(dnums.input_feature_dimension()); break; default: - return tensorflow::errors::Internal("Invalid input layout: ", - DataLayoutString(input)); + return InternalError("Invalid input layout %s for conv with dnums %s", + DataLayoutString(input), + ConvolutionDimensionNumbersToString(dnums)); } std::vector filter_layout; @@ -74,8 +77,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, filter_layout.push_back(dnums.kernel_input_feature_dimension()); break; default: - return tensorflow::errors::Internal("Invalid filter layout: ", - FilterLayoutString(filter)); + return InternalError("Invalid filter layout %s for conv with dnums %s", + FilterLayoutString(filter), + ConvolutionDimensionNumbersToString(dnums)); } std::vector output_layout; @@ -95,8 +99,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, output_layout.push_back(dnums.output_feature_dimension()); break; default: - return tensorflow::errors::Internal("Invalid output layout: ", - DataLayoutString(output)); + return InternalError("Invalid output layout %s for conv with dnums %s", + DataLayoutString(output), + ConvolutionDimensionNumbersToString(dnums)); } return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout), @@ -128,8 +133,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(input, nhwc_input)) { input_layout = DataLayout::kBatchYXDepth; } else { - return tensorflow::errors::Internal("Invalid input layout: ", - input.ShortDebugString()); + return InternalError("Invalid input layout %s for conv 
with dnums %s", + LayoutUtil::HumanString(input), + ConvolutionDimensionNumbersToString(dnums)); } FilterLayout filter_layout; @@ -138,8 +144,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(filter, nhwc_filter)) { filter_layout = FilterLayout::kOutputYXInput; } else { - return tensorflow::errors::Internal("Invalid filter layout: ", - filter.ShortDebugString()); + return InternalError("Invalid filter layout %s for conv with dnums %s", + LayoutUtil::HumanString(filter), + ConvolutionDimensionNumbersToString(dnums)); } DataLayout output_layout; @@ -148,8 +155,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(output, nhwc_output)) { output_layout = DataLayout::kBatchYXDepth; } else { - return tensorflow::errors::Internal("Invalid output layout: ", - output.ShortDebugString()); + return InternalError("Invalid output layout %s for conv with dnums %s", + LayoutUtil::HumanString(output), + ConvolutionDimensionNumbersToString(dnums)); } return std::make_tuple(input_layout, filter_layout, output_layout); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 4fad3f46cf953945e4f395e751e5ba76db97ecc4..db4a33dc564b62b5fe54b725ea453a6fcbfb3287 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -35,13 +35,13 @@ cc_library( "requires-gpu-sm35", ], deps = [ - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service/gpu:gpu_executable", "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -60,6 +60,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", 
"//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -94,6 +95,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -150,6 +152,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -168,6 +171,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc index 4b8415fe9106137e588f345a3492f93e46aeb5b6..79e77d4c4d649020cf52ac25c220c3f90e8469b9 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -14,8 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/platform/logging.h" @@ -32,15 +32,14 @@ std::unique_ptr GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) { debug_options.add_xla_disable_hlo_passes("constant_folding"); config.set_debug_options(debug_options); - return MakeUnique(TestName(), config); + return absl::make_unique(TestName(), config); } void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr hlo_module, const string& pattern) { std::unique_ptr executable = std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie()); - string ptx_str = - std::string(static_cast(executable.get())->ptx()); + string ptx_str(static_cast(executable.get())->ptx()); StatusOr filecheck_result = RunFileCheck(ptx_str, pattern); ASSERT_TRUE(filecheck_result.ok()); EXPECT_TRUE(filecheck_result.ValueOrDie()); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index ce69e058e64aab1f3c292b2ad7c7b529d4666b35..4550f36fdfc097632fed4956fcd3e42ef8a919c5 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -16,9 +16,9 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index e5958165eff21d82faf821213e50fe30a11059a4..a06576df7b874745236a8d9075355a01ec42e777 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index cca35316f0c472d2a17c466f8cd1af7f22575a8b..15d1e269cc22b88f5269175084f20600f165011c 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -27,13 +27,22 @@ namespace { class GpuKernelTilingTest : public GpuCodegenTest { protected: - GpuKernelTilingTest() { + GpuKernelTilingTest() {} + + // Most tests in this file want to skip layout assignment, but a few need it + // enabled. 
+ HloModuleConfig ConfigWithLayoutAssignment() { + return GetModuleConfigForTest(); + } + + HloModuleConfig ConfigWithoutLayoutAssignment() { + HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); - config_.set_debug_options(debug_options); // Disable layout_assignment to use the preassigned layouts. - debug_options.add_xla_disable_hlo_passes("layout_assignment"); + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + config.set_debug_options(debug_options); + return config; } - HloModuleConfig config_; }; TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { @@ -46,7 +55,13 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + // + // We must enable layout assignment in order for this test to work correctly. + // AlgebraicSimplifier removes copy1; it's added back by layout assignment, + // which respects the module's entry computation layout. But if we don't run + // layout assignment...well, nobody else adds the copy back. + auto hlo_module = + ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy @@ -68,8 +83,11 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0) })"; - // Check that a call to llvm.nvvm.barrier0 is not generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + // Check that a call to llvm.nvvm.barrier0 is not generated. As in + // UnnestedTransposeWithProperDimensionsTiled, we must run layout assignment + // here. 
+ auto hlo_module = + ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy @@ -95,7 +113,8 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion @@ -128,7 +147,8 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion @@ -162,7 +182,8 @@ TEST_F(GpuKernelTilingTest, })"; // Check that a call to llvm.nvvm.barrier0 is not generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index 6c9ae7bada5e7545b558b6fcb872ece60850cbe9..6a9ecd9dae7c9ddde0b56d8615e4a39fb3df0af9 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -20,8 +20,8 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index c42e5704a4d2e611a203293e60a86ba4104bca46..15198865bda98f9718342d5a444a20305f923b48 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 962293630683fcbbce3941f622061a2ff0f02dda..0f2d5568cafc9db0f5f067437fdd5e2e775ad2c8 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -138,6 +138,9 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) { HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_max_kernel_unroll_factor(2); + // Disable layout assignment for this test. Layout assignment does not expect + // fusions to be present, and so it does the wrong thing. 
+ debug_options.add_xla_disable_hlo_passes("layout-assignment"); config.set_debug_options(debug_options); const char *const kMultiOutputFusionModule = R"( diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index bdb062837c5ba4b588ea0d535a786f33fe4f4015..141f3219387940a08ef22cbcc0be0971a14c2cd6 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -144,16 +144,15 @@ const std::list& ThunkSchedule::DependsOn( string ThunkSchedule::ToString() const { string result = "Total order:\n"; for (Thunk* thunk : thunk_total_order_) { - tensorflow::strings::StrAppend(&result, "\t", - thunk->hlo_instruction()->ToString(), "\n"); + absl::StrAppend(&result, "\t", thunk->hlo_instruction()->ToString(), "\n"); } - tensorflow::strings::StrAppend(&result, "Dependencies:\n"); + absl::StrAppend(&result, "Dependencies:\n"); for (const auto& entry : depends_on_) { const Thunk* dependent = entry.first; for (const Thunk* dependency : entry.second) { - tensorflow::strings::StrAppend( - &result, "\t", dependent->hlo_instruction()->name(), " depends on ", - dependency->hlo_instruction()->name(), "\n"); + absl::StrAppend(&result, "\t", dependent->hlo_instruction()->name(), + " depends on ", dependency->hlo_instruction()->name(), + "\n"); } } return result; diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index 8579b1545fd24f80621ac0f53b997e33586cbabe..989b542ff4503600b2e3c751a23345959fab6fd6 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" @@ -25,7 +26,7 @@ Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { auto size = tuple_element_buffers_.size(); - auto tuple_element_buffer_addresses = MakeUnique(size); + auto tuple_element_buffer_addresses = absl::make_unique(size); for (int i = 0; i != size; ++i) { tuple_element_buffer_addresses[i] = buffer_allocations.GetDeviceAddress(tuple_element_buffers_[i]).opaque(); diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 2d5735d6c40ccd26f0e527f1a02403910db4c812..dcdbf2cf3c2aa87cc11a3473a765cb405b50e2a6 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -18,12 +18,12 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -34,8 +34,7 @@ namespace gpu { // issue (b/31336476). 
class TupleThunk : public Thunk { public: - TupleThunk(tensorflow::gtl::ArraySlice - tuple_element_buffers, + TupleThunk(absl::Span tuple_element_buffers, const BufferAllocation::Slice& dest_buffer, const HloInstruction* hlo_instruction) : Thunk(Kind::kTuple, hlo_instruction), diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index d81d87e7dc54cd752000b85f3ec173d66d7195e4..c4754fe378960834e1157b0ff25c03c0fc4754c7 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -34,9 +34,9 @@ WhileThunk::WhileThunk( // and body_thunk_sequence_ constructors because these SequentialThunks // are logically "part of" this WhileThunk, and shouldn't be profiled // separately from it. 
- condition_thunk_sequence_(MakeUnique( + condition_thunk_sequence_(absl::make_unique( std::move(*condition_thunk_sequence), nullptr)), - body_thunk_sequence_(MakeUnique( + body_thunk_sequence_(absl::make_unique( std::move(*body_thunk_sequence), nullptr)) {} Status WhileThunk::Initialize(const GpuExecutable& executable, @@ -70,7 +70,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, if (!block_status.ok()) { return InternalError( "Failed to complete all kernels launched on stream %p: %s", stream, - block_status.error_message().c_str()); + block_status.error_message()); } if (!condition_result) { diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index c5f3906356d821e059d2b1213c9083c4408a4d1c..40183de96ee363996e6b0b883a78e7a8b5d13ab2 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -118,7 +118,8 @@ class WhileTransformerTest : public HloTestBase { } void RunCopyInsertionPass() { - HloVerifier verifier; + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index aa89567ee86e59e197045c0b51eed3b9aa59fef7..a2be89511babc23ebcd5cb40abee2a95d16dc451 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -22,9 +22,10 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/types.h" @@ -43,8 +43,7 @@ namespace { // Adds a computation to the given HLO module which adds a scalar constant to // its parameter and returns the result. HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { - auto builder = - HloComputation::Builder(tensorflow::strings::StrCat("add_", addend)); + auto builder = HloComputation::Builder(absl::StrCat("add_", addend)); auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "x_value")); auto half = builder.AddInstruction( @@ -84,7 +83,7 @@ HloComputation* CallForwardingComputation(HloComputation* computation, // the module. std::unique_ptr MakeBigGraph() { HloModuleConfig config; - auto module = MakeUnique("BigGraph", config); + auto module = absl::make_unique("BigGraph", config); auto builder = HloComputation::Builder("TestBigGraphvizGraph"); diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 4005fc0d114a3ec7a38dfb5edecdaeb1e8497ade..38c3982ebf170d5733d56a05106835d1eaa4f2e1 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/util.h" @@ -45,7 +46,7 @@ StatusOr HeapSimulator::MinimumMemoryForModule( // bound, by minimizing the liveness of sub-computations. TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), *module, + HeapSimulator::Run(absl::make_unique(), *module, module_sequence, *points_to_analysis, size_function)); return result.heap_size; } @@ -60,9 +61,10 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( memory_by_computation) { TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function, - HeapSimulator::Options(), memory_by_computation)); + HeapSimulator::Run(absl::make_unique(), + computation, sequence, points_to_analysis, + size_function, HeapSimulator::Options(), + memory_by_computation)); return result.heap_size; } @@ -142,7 +144,7 @@ Status HeapSimulator::RunComputation( } } else { // A GetTupleElement doesn't need to keep all of its operand's buffers - // alive. It only needs the buffers that relate to the element its + // alive. It only needs the buffers that relate to the element it's // extracting, and the tuple it's extracting from, but not the buffers // for the other elements. for (const BufferValue* buffer : points_to.element({})) { @@ -275,13 +277,13 @@ Status HeapSimulator::RunComputation( *memory_by_computation_); } - // If the whole module is sequential, we can save memory by running the - // heap-simulation for sub-computations inline. E.g. the buffers for the - // condition and body of a kWhile instruction are only live for the duration - // of the instruction itself. + // If all computations in the module have been scheduled, we can save memory + // by running the heap-simulation for sub-computations inline. E.g. 
the + // buffers for the condition and body of a kWhile instruction are only live + // for the duration of the instruction itself. // // The order that the sub-computations are simulated does not affect - // correctness; since the whole module is sequential, we know that the + // correctness; since the whole module has been scheduled, we know that the // sub-computations will never be run concurrently. if (module_sequence_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || @@ -344,7 +346,7 @@ HeapSimulator::HeapSimulator( const SequentialHloOrdering::HloModuleSequence* module_sequence, const tensorflow::gtl::FlatMap* memory_by_computation) - : no_fragmentation_stats_(MakeUnique()), + : no_fragmentation_stats_(absl::make_unique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), @@ -378,9 +380,10 @@ void HeapSimulator::Alloc(const BufferValue* buffer, allocated_buffers_.insert(buffer); const int64 size = size_fn_(*buffer); - algorithm_->Alloc(buffer, size); - no_fragmentation_stats_->Alloc(buffer, size); - + const HloInstruction* instruction_to_calc_aliasing = + memory_by_computation_ == nullptr ? nullptr : instruction; + algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing); + no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing); FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction, nullptr); } @@ -518,6 +521,18 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { } } +void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size, + const HloInstruction* instruction) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. 
+ if (instruction == nullptr || + (instruction->opcode() != HloOpcode::kWhile && + instruction->opcode() != HloOpcode::kCall && + instruction->opcode() != HloOpcode::kConditional)) { + Alloc(buffer, size); + } +} + void NoFragmentationStatsHeap::AccountForSubcomputationMemory( const HloInstruction* instruction, const tensorflow::gtl::FlatMap& diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 811a6042df9434ac3f4bed71b9c093433e25c1bb..af05bedee72d4878f83765e5a5c5baf61bd71ba2 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -36,6 +36,7 @@ namespace xla { // Forward declare classes defined below. class HeapAlgorithm; +class NoFragmentationStatsHeap; // HeapSimulator assigns buffer offsets by running a simulation of a regular // memory heap with Alloc and Free calls. It only works for completely @@ -161,7 +162,10 @@ class HeapSimulator { const HloInstruction* instruction, const BufferValue* shared_with_canonical); - const std::unique_ptr no_fragmentation_stats_; + // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap, + // in which case we are calculating the same allocs/frees twice in the + // simulation. + const std::unique_ptr no_fragmentation_stats_; const std::unique_ptr algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; @@ -216,6 +220,21 @@ class HeapAlgorithm { // Alloc allocates a buffer of 'size' bytes. virtual void Alloc(const BufferValue* buffer, int64 size) = 0; + // NoFragmentationStatsHeap overrides this method. + virtual void Alloc(const BufferValue* buffer, int64 size, + const HloInstruction* instruction) { + Alloc(buffer, size); + } + + // Takes memory usage of subcomputations into account when calculating the + // memory usage of a computation. Currently, we don't handle buffer aliasing + // between computations entirely correctly. 
We are careful to not double count + // for the output buffers of whiles/conds/calls. But we don't take into + // account other aliases, such as for the while init. A more thorough solution + // would require something like BufferAssignment::BuildColocatedBufferSets. + // TODO(b/65835246): + // Since TuplePointsToAnalysis is being replaced with a module-aware alias + // analysis, it's not worth making major changes to HeapSimulator now. virtual void AccountForSubcomputationMemory( const HloInstruction* instruction, const tensorflow::gtl::FlatMap& @@ -240,6 +259,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { void Alloc(const BufferValue* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size, + const HloInstruction* instruction) override; + void AccountForSubcomputationMemory( const HloInstruction* instruction, const tensorflow::gtl::FlatMap& diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index b41dc66fe9f5e869a114be96b7cc01fc1a3d59da..5f85f145657b67634844c849447ef545a6dea468 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -137,7 +138,7 @@ class HeapSimulatorTracker { const string& name, std::unique_ptr computation, const std::vector& instruction_sequence) { HloModuleConfig config; - module_ = MakeUnique(name, config); + module_ = absl::make_unique(name, config); module_->AddEntryComputation(std::move(computation)); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -146,8 +147,8 @@ class HeapSimulatorTracker { // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by // buffer id, for determinism in the tests. auto zero_size = [](const BufferValue& buffer) { return 0; }; - auto algorithm = MakeUnique( - MakeUnique(&actual_calls_)); + auto algorithm = absl::make_unique( + absl::make_unique(&actual_calls_)); result_ = HeapSimulator::Run( std::move(algorithm), *module_->entry_computation(), instruction_sequence, *points_to_analysis_, zero_size) @@ -156,7 +157,7 @@ class HeapSimulatorTracker { explicit HeapSimulatorTracker(const string& name) { HloModuleConfig config; - module_ = MakeUnique(name, config); + module_ = absl::make_unique(name, config); } // Similar to the single entry computation constructor above, but runs the @@ -182,8 +183,8 @@ class HeapSimulatorTracker { auto size_fn = [&reverse_position](const BufferValue& buffer) { return reverse_position[buffer.instruction()]; }; - auto algorithm = MakeUnique( - MakeUnique(&actual_calls_)); + auto algorithm = absl::make_unique( + absl::make_unique(&actual_calls_)); result_ = HeapSimulator::Run(std::move(algorithm), *module_, module_sequence, *points_to_analysis_, size_fn) .ConsumeValueOrDie(); @@ -675,7 +676,8 @@ class HeapAlgorithmTestBase : public ::testing::Test { const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( 
HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - buffers_.emplace_back(MakeUnique(id, const0, ShapeIndex{})); + buffers_.emplace_back( + absl::make_unique(id, const0, ShapeIndex{})); return buffers_.back().get(); } @@ -724,7 +726,8 @@ class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {}; TEST_F(DecreasingSizeRunsHeapTest, Empty) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique(&call_sequence)); heap.Finish(); EXPECT_EQ(call_sequence, CallSequence({ {kFinish, nullptr}, @@ -733,7 +736,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Empty) { TEST_F(DecreasingSizeRunsHeapTest, Simple) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique(&call_sequence)); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Alloc(buffer_c_, 30); @@ -760,7 +764,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Simple) { TEST_F(DecreasingSizeRunsHeapTest, Mixed) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique(&call_sequence)); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Free(buffer_b_, 20); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index be9098f555e78f3cabfe55481356f8b6841a3a2b..58b7af93ebfce74951c0f2d65ab226fc94d62e4b 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,6 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. 
+// Next ID: 53 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -45,6 +46,8 @@ message HloInstructionProto { reserved "control_predecessor_names"; reserved 6; reserved "called_computation_names"; + reserved 44; + reserved "replica_group_ids"; string name = 1; string opcode = 2; @@ -74,6 +77,11 @@ message HloInstructionProto { // Describes the dimension numbers used for a convolution. xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16; + // The number of feature groups. Used for a convolution. Must be a divisor of + // the input feature dimension and output feature dimension. If not specified, + // it will use a default value of 1. + int64 feature_group_count = 50; + // Describes the [begin, end) index range and stride for slices. message SliceDimensions { int64 start = 1; @@ -133,7 +141,7 @@ message HloInstructionProto { // Gather dimension numbers. xla.GatherDimensionNumbers gather_dimension_numbers = 33; - repeated int64 gather_window_bounds = 34; + repeated int64 gather_slice_sizes = 34; // Compute Host. string channel_name = 41; @@ -152,9 +160,6 @@ message HloInstructionProto { string backend_config = 43; // Cross replica op fields. - // TODO(b/112107579): remove replica_group_ids field and always use - // replica_groups. - repeated int64 replica_group_ids = 44; repeated ReplicaGroup replica_groups = 49; int64 all_reduce_id = 45; string cross_replica_sum_barrier = 46; @@ -165,6 +170,12 @@ message HloInstructionProto { bool is_host_transfer = 47; xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; + + // Precision configuration for the instruction. Has backend-specific meaning. + xla.PrecisionConfigProto precision_config = 51; + + // Collective permute field. + repeated SourceTarget source_target_pairs = 52; } // Serialization of HloComputation. 
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index e8a4b034b4396860bd5873f43003844ce92dea6c..0986da65cbd3d550ecfa01212364518aba651d86 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -28,15 +30,11 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; // Data structure used to construct the alias analysis. Thrown away after alias // analysis is complete. 
This data structure keeps track of which sets of @@ -414,7 +412,7 @@ Status HloAliasAnalysis::Verify() const { } string HloAliasAnalysis::ToString() const { - string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); + string out = absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); StrAppend(&out, " Buffers at each position:\n"); for (const HloComputation* computation : module_->computations()) { for (const HloInstruction* instruction : computation->instructions()) { @@ -457,7 +455,7 @@ StatusOr> HloAliasAnalysis::Run( VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); - auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); + auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module)); TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_, HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, /*bitcast_defines_value=*/false, @@ -537,10 +535,10 @@ bool HloAliasAnalysis::HasLiveRangeInterference( if (ordering.MayInterfere(*values[i - 1], *values[i], dataflow_analysis())) { VLOG(1) << "In buffer " << buffer.id() << " containing values:\n " - << Join(values, ", ", - [](string* out, const HloValue* value) { - StrAppend(out, value->ToShortString()); - }) + << absl::StrJoin(values, ", ", + [](string* out, const HloValue* value) { + StrAppend(out, value->ToShortString()); + }) << "\nValue " << values[i - 1]->ToShortString() << " may interfere with value " << values[i]->ToShortString(); diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index 1fea544730c27efdaa260f55ea81c163165f7ed5..e345804537723f01e9ccb63e7d6ded1bd68f4196 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index da94ab5346e5628b4a603b3ac2d84071904d1e65..54abe3345d25a8cc1fdd66bd6ee75157fe9b7f77 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" @@ -39,15 +39,17 @@ namespace { using ::testing::UnorderedElementsAre; -class HloAliasAnalysisTest : public HloTestBase { +class HloAliasAnalysisTest : public HloVerifiedTestBase { protected: - HloAliasAnalysisTest() : module_(CreateNewModule()) {} + HloAliasAnalysisTest() : HloVerifiedTestBase() { + module_ = CreateNewModule(); + } // Run alias analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. 
HloAliasAnalysis& RunAnalysis() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before alias analysis"); - analysis_ = HloAliasAnalysis::Run(module_.get(), + analysis_ = HloAliasAnalysis::Run(module_, /*fusion_can_share_buffer=*/nullptr) .ConsumeValueOrDie(); return *analysis_; @@ -91,7 +93,7 @@ class HloAliasAnalysisTest : public HloTestBase { // never occurs, but HLO graphs with interference can be explicitly // constructed. bool AnyValuesInSameBufferInterfere() { - DependencyHloOrdering ordering(module_.get()); + DependencyHloOrdering ordering(module_); for (const HloBuffer& buffer : analysis_->buffers()) { for (const HloValue* value_a : buffer.values()) { for (const HloValue* value_b : buffer.values()) { @@ -108,7 +110,7 @@ class HloAliasAnalysisTest : public HloTestBase { return false; } - std::unique_ptr module_; + HloModule* module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); @@ -461,7 +463,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { module_->AddEntryComputation(builder.Build()); FlattenCallGraph flattener; - TF_ASSERT_OK(flattener.Run(module_.get()).status()); + TF_ASSERT_OK(flattener.Run(module_).status()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -835,7 +837,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { const HloAliasAnalysis& analysis = RunAnalysis(); - DependencyHloOrdering ordering(module_.get()); + DependencyHloOrdering ordering(module_); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } @@ -877,7 +879,7 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { { // Dependency ordering should interfere because the negate and while are // unordered. 
- DependencyHloOrdering ordering(module_.get()); + DependencyHloOrdering ordering(module_); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } @@ -888,13 +890,13 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { sequence[condition] = {cond_param, cond_root}; { sequence[entry] = {init, xla_while, negate, entry_root}; - SequentialHloOrdering ordering(module_.get(), sequence); + SequentialHloOrdering ordering(module_, sequence); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } { sequence[entry] = {init, negate, xla_while, entry_root}; - SequentialHloOrdering ordering(module_.get(), sequence); + SequentialHloOrdering ordering(module_, sequence); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } } diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index e16413f361fb0216792b47c3c67ef3c1357c2221..6c11a073b74c61e44dfe81a32261ae78ae7b46fb 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -27,15 +29,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrCat; - bool HloBuffer::operator==(const HloBuffer& other) const { bool equal = id() == other.id(); if (equal) { @@ -59,10 +56,11 @@ std::vector HloBuffer::ComputePositions() const { } string HloBuffer::ToString() const { - return StrCat("HloBuffer ", id_, ", values: ", - Join(values_, ", ", [](string* result, const HloValue* value) { - result->append(value->ToShortString()); - })); + return absl::StrCat( + "HloBuffer ", id_, ", values: ", + absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) { + result->append(value->ToShortString()); + })); } std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) { diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h index 4873463b2ea4fee3ee39dff31fc3429a4998142f..a88c87e46c8100571aff24f70a2a19fe8ce71ebc 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.h +++ b/tensorflow/compiler/xla/service/hlo_buffer.h @@ -84,7 +84,7 @@ class HloBuffer { return a->id() == b->id(); } - HloBuffer(Id id, tensorflow::gtl::ArraySlice values) + HloBuffer(Id id, absl::Span values) : id_(id), values_(values.begin(), values.end()) {} // Return the unique identifier for this HloBuffer. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 441288da1a6859a3f393a298ee02eb4b435e42e0..fe7f2be888d2037e4f6d3879bcc716de4eee07f9 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -23,9 +23,13 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -36,13 +40,11 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::strings::StrCat; +using absl::StrCat; std::unique_ptr HloComputation::Builder::Build( HloInstruction* root_instruction) { @@ -56,8 +58,8 @@ std::unique_ptr HloComputation::Builder::Build( HloInstruction* root = root_instruction ? 
root_instruction : last_added_instruction_; CHECK_NE(nullptr, root); - return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, - root, fusion_instruction_)); + return absl::WrapUnique(new HloComputation( + name_, parameter_count, &instructions_, root, fusion_instruction_)); } HloComputation::HloComputation( @@ -135,7 +137,7 @@ string RenameFusionParameter(const string& original_name, int64 new_param_no) { } string after_param = original_name.substr(index + param_underscore.size()); int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { + if (absl::SimpleAtoi(after_param, &numeric_suffix)) { return StrCat(original_name.substr(0, index + param_underscore.size()), new_param_no); } @@ -317,11 +319,12 @@ void ComputeComputationPostOrder( } } -enum State { kVisiting, kVisited }; +} // namespace -void ComputeInstructionPostOrder( +void HloComputation::ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap* visited) { + tensorflow::gtl::FlatMap* visited) const { std::vector dfs_stack; dfs_stack.push_back(root); while (!dfs_stack.empty()) { @@ -354,16 +357,71 @@ void ComputeInstructionPostOrder( for (HloInstruction* op : current->control_predecessors()) { dfs_stack.emplace_back(op); } + + // Add inputs for send->recv_done dependencies and cross-replica-sum + // dependencies. 
+ switch (current->opcode()) { + case HloOpcode::kRecvDone: { + auto it = channel_dependency_map.find(current->channel_id()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = current->all_reduce_id(); + if (all_reduce_id) { + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } + } + } + break; + } + default: + break; + } } } -} // namespace +HloComputation::ChannelDependencyMap +HloComputation::ComputeChannelDependencies() const { + ChannelDependencyMap channel_dependency_map; + for (const auto& instruction : instructions_) { + switch (instruction->opcode()) { + case HloOpcode::kSend: { + channel_dependency_map[instruction->channel_id()].push_back( + instruction.get()); + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = instruction->all_reduce_id(); + if (all_reduce_id) { + auto& dependencies = channel_dependency_map[all_reduce_id.value()]; + absl::c_copy(instruction->operands(), + std::back_inserter(dependencies)); + absl::c_copy(instruction->control_predecessors(), + std::back_inserter(dependencies)); + } + break; + } + default: + break; + } + } + return channel_dependency_map; +} std::vector HloComputation::MakeInstructionPostOrder() const { + auto channel_dependency_map = ComputeChannelDependencies(); std::vector post_order; post_order.reserve(instruction_count()); std::vector trace_instructions; - tensorflow::gtl::FlatMap visited; + tensorflow::gtl::FlatMap visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -371,7 +429,8 @@ std::vector HloComputation::MakeInstructionPostOrder() const { // users). 
trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - ComputeInstructionPostOrder(&post_order, instruction.get(), &visited); + ComputeInstructionPostOrder(channel_dependency_map, &post_order, + instruction.get(), &visited); } } post_order.insert(post_order.end(), trace_instructions.begin(), @@ -493,13 +552,13 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); - return WrapUnique(new HloComputation(proto.name(), parameter_count, - &instructions, root, - /*fusion_instruction=*/nullptr)); + return absl::WrapUnique(new HloComputation(proto.name(), parameter_count, + &instructions, root, + /*fusion_instruction=*/nullptr)); } void HloComputation::FuseInstructionsInto( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction* fusion_instruction) { CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); HloInstruction* root = instructions_to_fuse.front(); @@ -518,7 +577,7 @@ void HloComputation::FuseInstructionsInto( } HloInstruction* HloComputation::CreateFusionInstruction( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction::FusionKind fusion_kind) { HloInstruction* root = instructions_to_fuse.front(); HloInstruction* fusion_instruction = AddInstruction( @@ -566,16 +625,15 @@ StatusOr HloComputation::DeepCopyInstruction( if (instruction->parent() != this) { return FailedPrecondition( "Can't deep copy instruction %s: instruction is not in computation %s", - instruction->name().c_str(), name().c_str()); + instruction->name(), name()); } if (indices_to_copy != nullptr && !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) { return FailedPrecondition( "Can't deep copy instruction %s: given shape tree of indices to copy " "has incompatible shapes: %s vs. 
%s", - instruction->name().c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - ShapeUtil::HumanString(indices_to_copy->shape()).c_str()); + instruction->name(), ShapeUtil::HumanString(instruction->shape()), + ShapeUtil::HumanString(indices_to_copy->shape())); } ShapeIndex index; @@ -605,7 +663,7 @@ StatusOr HloComputation::DeepCopyInstructionWithCustomCopier( if (instruction->parent() != this) { return FailedPrecondition( "Can't deep copy instruction %s: instruction is not in computation %s", - instruction->name().c_str(), name().c_str()); + instruction->name(), name()); } ShapeIndex index; return DeepCopyHelper(instruction, &index, copy_leaf); @@ -624,6 +682,9 @@ ProgramShape HloComputation::ComputeProgramShape() const { } bool HloComputation::operator==(const HloComputation& other) const { + if (this == &other) { + return true; + } std::set> visited; std::function eq = [&visited, &eq](const HloInstruction* a, const HloInstruction* b) { @@ -674,13 +735,37 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, std::unique_ptr HloComputation::ComputeReachability() const { const auto& all = MakeInstructionPostOrder(); - auto result = MakeUnique(all); + auto result = absl::make_unique(all); + auto channel_dependency_map = ComputeChannelDependencies(); std::vector inputs; for (const HloInstruction* hlo : all) { inputs.assign(hlo->operands().begin(), hlo->operands().end()); inputs.insert(inputs.end(), hlo->control_predecessors().begin(), hlo->control_predecessors().end()); + + switch (hlo->opcode()) { + case HloOpcode::kRecvDone: { + auto it = channel_dependency_map.find(hlo->channel_id()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = hlo->all_reduce_id(); + if (all_reduce_id) { + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + 
absl::c_copy(it->second, std::back_inserter(inputs)); + } + } + break; + } + default: + break; + } + result->FastSetReachabilityToUnion(inputs, hlo); } return result; @@ -723,11 +808,10 @@ std::vector HloComputation::CollectUnreachableRoots() const { } } VLOG(3) << "Unreachable roots:" - << tensorflow::str_util::Join( - unreachable_roots, "\n\t", - [](string* out, const HloInstruction* hlo) { - tensorflow::strings::StrAppend(out, hlo->ToString()); - }); + << absl::StrJoin(unreachable_roots, "\n\t", + [](string* out, const HloInstruction* hlo) { + absl::StrAppend(out, hlo->ToString()); + }); return unreachable_roots; } @@ -829,7 +913,7 @@ std::unique_ptr HloComputation::CloneWithReplacements( HloCloneContext* context, const string& suffix) { std::unique_ptr context_ptr; if (context == nullptr) { - context_ptr = MakeUnique(parent(), suffix); + context_ptr = absl::make_unique(parent(), suffix); context = context_ptr.get(); } @@ -898,12 +982,11 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) { name_ = name_uniquer->GetUniqueName(name_); } -HloInstruction* HloComputation::GetInstructionWithName( - tensorflow::StringPiece name) { +HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) { auto instructions_in_computation = instructions(); - auto it = c_find_if(instructions_in_computation, [&](HloInstruction* instr) { - return instr->name() == name; - }); + auto it = absl::c_find_if( + instructions_in_computation, + [&](HloInstruction* instr) { return instr->name() == name; }); return it == instructions_in_computation.end() ? nullptr : *it; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 49ed65910f519810740b89760ad815f287e59a91..fe2d3bbbe53bdcb7b2ea8a35f35e50fb3e8823b4 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -25,6 +25,7 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" @@ -237,7 +237,7 @@ class HloComputation { // removed if they have no uses after fusion (this is necessarily true for at // least the root). HloInstruction* CreateFusionInstruction( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction::FusionKind fusion_kind); // Create a deep copy of the given instruction and return the instruction @@ -367,7 +367,7 @@ class HloComputation { // Returns the instruction in this computation that has name `name`. Returns // null if there is no such computation. - HloInstruction* GetInstructionWithName(tensorflow::StringPiece name); + HloInstruction* GetInstructionWithName(absl::string_view name); int64 unique_id() const { return unique_id_; } @@ -385,7 +385,7 @@ class HloComputation { // // Pre-condition: fusion_instruction's opcode is kFusion. void FuseInstructionsInto( - tensorflow::gtl::ArraySlice instructions_to_fuse, + absl::Span instructions_to_fuse, HloInstruction* fusion_instruction); // Internal helper for recursive copying of an instruction. Creates and @@ -399,6 +399,20 @@ class HloComputation { // Internal helper to collect unreachable roots. std::vector CollectUnreachableRoots() const; + // Returns a map from channel-id to directed dependencies of the channel + // instructions. 
For send&recv pairs it means the send instruction and for + // cross-replica-sum the union of the dependencies for all participating + // instructions. + using ChannelDependencyMap = + tensorflow::gtl::FlatMap>; + ChannelDependencyMap ComputeChannelDependencies() const; + + enum VisitState { kVisiting, kVisited }; + void ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, + std::vector* post_order, HloInstruction* root, + tensorflow::gtl::FlatMap* visited) const; + string name_; int64 unique_id_; HloInstruction* root_instruction_; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index e4c547033139185d5dd4ef37db2d22a6431c1102..f7ed1b0316b213a0f34b1d690229f0173dbd5250 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -691,6 +691,27 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } -} // namespace +TEST_F(HloComputationTest, ChannelReachability) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); + HloComputation::Builder builder("ChannelReachability"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto send = + builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto recv = + builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build(recv_done)); + auto reachability = computation->ComputeReachability(); + 
EXPECT_TRUE(reachability->IsReachable(param, recv_done)); + EXPECT_FALSE(reachability->IsReachable(send, recv)); + EXPECT_FALSE(reachability->IsReachable(send_done, recv)); +} + +} // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 7229031c0c7f8bd374cfb495c7d8c11e9ca8b95e..8a45939c61755876555bc35c49d7d6c781f8b4fe 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -38,7 +39,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { // Limit the constant folding to 0 iterations to skip folding loops. This // retains the behavior from before while loop support in HloEvaluator and may // be revised. - auto evaluator = MakeUnique(/*max_loop_iterations=*/0); + auto evaluator = absl::make_unique(/*max_loop_iterations=*/0); XLA_VLOG_LINES(2, "HloConstantFolding::Run(), before:\n" + module->ToString()); @@ -51,9 +52,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, Reduce, and AfterAll operation. - // TODO(b/35975797): Enable Reduce operation once arbitrary computation - // are supported by the evaluator. + // Skip Constant, Parameter, and AfterAll operation. // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. 
// TODO(b/110532604): Enable AfterAll once AfterAll requires at least one // operand in which case constant folding will be impossible and this @@ -61,7 +60,6 @@ StatusOr HloConstantFolding::Run(HloModule* module) { if (instruction->opcode() == HloOpcode::kParameter || instruction->opcode() == HloOpcode::kConstant || instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kReduce || instruction->opcode() == HloOpcode::kAfterAll) { continue; } @@ -73,7 +71,8 @@ StatusOr HloConstantFolding::Run(HloModule* module) { // Broadcasts dramatically increase the size of constants, which is often // detrimental to performance and memory capacity, so do not fold // broadcasts. - if (instruction->opcode() == HloOpcode::kBroadcast) { + if (instruction->opcode() == HloOpcode::kBroadcast || + instruction->opcode() == HloOpcode::kIota) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h index 331480bd029727fa15476cb9ced2e7b7afd170f3..4557983a9c0b0006cc2189c96a88478d469475c1 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.h +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h @@ -25,7 +25,7 @@ namespace xla { // computation on constants. class HloConstantFolding : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "constant_folding"; } + absl::string_view name() const override { return "constant_folding"; } // Run constant folding operations on the given module. Returns whether the // module was changed (constant expressions folded). 
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 64a42c1efc0c788ae8e66fb72b2d9aecec179082..07cd1efc1208309770478885532e0284bdb1fbcc 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -104,8 +105,8 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { TEST_F(HloConstantFoldingTest, Concatenate) { const struct TestConfig { int concat_dimension; - tensorflow::gtl::ArraySlice dimensions; - tensorflow::gtl::ArraySlice concat_sizes; + absl::Span dimensions; + absl::Span concat_sizes; } test_configs[] = { {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}}, {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}}, @@ -195,12 +196,52 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; bool matched = true; root->literal().EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + [&](absl::Span indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); matched = matched && (value == literal_clone->Get(rindexes)); }); EXPECT_TRUE(matched); } +const char* const kConstantFoldReduce = R"( + HloModule ConstantFoldReduce + + add { + a = s32[] parameter(0) + b = s32[] parameter(1) + ROOT add = s32[] add(a, b) + } + + ENTRY r { + x = s32[3] constant({1, 2, 3}) + init = s32[] constant(0) + ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add + })"; + +TEST_F(HloConstantFoldingTest, 
ConstantFoldReduce) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kConstantFoldReduce)); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + EXPECT_EQ(6, module->entry_computation() + ->root_instruction() + ->literal() + .GetFirstElement()); +} + +TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(kConstantFoldReduce)); + HloInstruction* add = module->computations().begin()->root_instruction(); + LayoutUtil::ClearLayout(add->mutable_shape()); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + EXPECT_FALSE(result); + + EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 1bbb0ff08e26f626f4c3992a5f20ec4990f7db2d..939b5114c3f8f93ad2d768e77db302ae83e44d17 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -258,10 +258,6 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) { - return Status::OK(); -} - Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, @@ -278,15 +274,21 @@ Status HloCostAnalysis::HandleMap(const HloInstruction* map) { } Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { - auto arg = reduce->operand(0); HloComputation* function = reduce->to_apply(); // Compute the cost of the user function. 
TF_ASSIGN_OR_RETURN(const Properties sub_properties, ProcessSubcomputation(function)); // Compute the cost of all elements for this Reduce operation. - int64 reduction_count = ShapeUtil::ElementsIn(arg->shape()) - - ShapeUtil::ElementsIn(reduce->shape()); + // This counts the number of times the reduction function is applied, so it + // does not need to be multiplied by the number of input tensors - that's + // already "priced in" by the sub-computation doing more work. + auto arg = reduce->operand(0); + auto output_shape = ShapeUtil::IsArray(reduce->shape()) + ? reduce->shape() + : reduce->shape().tuple_shapes(0); + int64 reduction_count = + ShapeUtil::ElementsIn(arg->shape()) - ShapeUtil::ElementsIn(output_shape); for (const auto& property : sub_properties) { if (property.first != kBytesAccessedKey) { current_properties_[property.first] = property.second * reduction_count; @@ -544,15 +546,10 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { } Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { - // TODO(b/110096724): Compute correct cost here. - double flops = 0.0; - ShapeUtil::ForEachSubshape(hlo->shape(), - [&](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsArray(subshape)) { - flops += ShapeUtil::ElementsIn(subshape); - } - }); - current_properties_[kFlopsKey] = flops; + return Status::OK(); +} + +Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 193a04bea0831de2b3aca19b17a445ad73e02e49..9bb3f12ee2c7867d71de61c5077f129fdf59ef75 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -72,9 +72,9 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; + Status HandleCollectivePermute(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; - Status HandleHostCompute(const HloInstruction* host_compute) override; Status HandleRng(const HloInstruction* random) override; Status HandleReverse(const HloInstruction* reverse) override; Status HandleSort(const HloInstruction* sort) override; diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 858992a3264a7fe6217374dc1e53f35ef77763c1..19ffb465c04ccc720ba6a8a14b187691a62b5c24 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -14,15 +14,16 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" namespace xla { -using tensorflow::gtl::ArraySlice; -using tensorflow::strings::StrCat; +using absl::StrCat; StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { @@ -48,9 +49,9 @@ StatusOr MakePadHlo(HloInstruction* operand, } StatusOr MakeSliceHlo(HloInstruction* operand, - ArraySlice start_indices, - ArraySlice limit_indices, - ArraySlice strides) { + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN(Shape slice_shape, ShapeInference::InferSliceShape( operand->shape(), start_indices, @@ -72,7 +73,7 @@ StatusOr MakeConvolveHlo( } StatusOr MakeTransposeHlo(HloInstruction* operand, - ArraySlice dimensions) { + absl::Span dimensions) { HloComputation* computation = operand->parent(); TF_ASSIGN_OR_RETURN( Shape transpose_shape, @@ -89,15 +90,15 @@ StatusOr MakeReshapeHlo(const Shape& result_shape, } StatusOr MakeReshapeHlo( - ArraySlice result_shape_dim_bounds, HloInstruction* operand) { + absl::Span result_shape_dim_bounds, HloInstruction* operand) { Shape new_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_dim_bounds); return MakeReshapeHlo(new_shape, operand); } -StatusOr MakeDynamicSliceHlo(HloInstruction* operand, - HloInstruction* start_indices, - ArraySlice slice_sizes) { +StatusOr MakeDynamicSliceHlo( + HloInstruction* operand, HloInstruction* start_indices, + absl::Span slice_sizes) { HloComputation* computation = 
operand->parent(); CHECK_EQ(computation, start_indices->parent()); TF_ASSIGN_OR_RETURN( @@ -123,8 +124,8 @@ StatusOr MakeDynamicUpdateSliceHlo( } StatusOr MakeBroadcastHlo( - HloInstruction* operand, ArraySlice broadcast_dimensions, - ArraySlice result_shape_bounds) { + HloInstruction* operand, absl::Span broadcast_dimensions, + absl::Span result_shape_bounds) { HloComputation* computation = operand->parent(); Shape broadcast_shape = ShapeUtil::MakeShape(operand->shape().element_type(), result_shape_bounds); @@ -144,18 +145,18 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, HloInstruction::CreateGetTupleElement(gte_shape, operand, index)); } -StatusOr MakeConcatHlo(ArraySlice operands, - int64 dimension) { +StatusOr MakeConcatHlo( + absl::Span operands, int64 dimension) { CHECK_GT(operands.size(), 0); HloComputation* computation = operands[0]->parent(); - CHECK(c_all_of(operands, [&](HloInstruction* instr) { + CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) { return instr->parent() == computation; })); std::vector operand_shapes; - c_transform(operands, std::back_inserter(operand_shapes), - [](HloInstruction* instr) { return &instr->shape(); }); + absl::c_transform(operands, std::back_inserter(operand_shapes), + [](HloInstruction* instr) { return &instr->shape(); }); TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape( operand_shapes, dimension)); @@ -174,9 +175,8 @@ StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); } -StatusOr MakeMapHlo( - tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation) { +StatusOr MakeMapHlo(absl::Span operands, + HloComputation* map_computation) { CHECK(!operands.empty()) << "Map Hlo requires at least one operand."; HloComputation* computation = operands.front()->parent(); std::vector operand_shapes; @@ -228,19 +228,19 @@ StatusOr PrependDegenerateDims(HloInstruction* operand, const Shape& operand_shape 
= operand->shape(); new_shape_dims.reserve(n + operand_shape.dimensions_size()); new_shape_dims.insert(new_shape_dims.begin(), n, 1); - c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); + absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); return MakeReshapeHlo(new_shape_dims, operand); } StatusOr ExpandFirstDimIntoNDims( - HloInstruction* operand, ArraySlice expanded_dims) { + HloInstruction* operand, absl::Span expanded_dims) { CHECK_GT(operand->shape().dimensions_size(), 0); CHECK_EQ(operand->shape().dimensions(0), Product(expanded_dims)); std::vector expanded_shape_dim_bounds; expanded_shape_dim_bounds.reserve(expanded_dims.size() + operand->shape().dimensions_size() - 1); - c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); + absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); std::copy(operand->shape().dimensions().begin() + 1, operand->shape().dimensions().end(), std::back_inserter(expanded_shape_dim_bounds)); @@ -249,9 +249,9 @@ StatusOr ExpandFirstDimIntoNDims( return MakeReshapeHlo(new_shape, operand); } -StatusOr ElideDegenerateDims(HloInstruction* operand, - ArraySlice dims_to_elide) { - CHECK(c_is_sorted(dims_to_elide)); +StatusOr ElideDegenerateDims( + HloInstruction* operand, absl::Span dims_to_elide) { + CHECK(absl::c_is_sorted(dims_to_elide)); const Shape& input_shape = operand->shape(); // First accumulate in reverse @@ -268,15 +268,15 @@ StatusOr ElideDegenerateDims(HloInstruction* operand, } } - c_reverse(new_shape_dim_bounds); + absl::c_reverse(new_shape_dim_bounds); Shape output_shape = ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds); return MakeReshapeHlo(output_shape, operand); } StatusOr InsertDegenerateDims( - HloInstruction* operand, ArraySlice dims_to_insert) { - CHECK(c_is_sorted(dims_to_insert)); + HloInstruction* operand, absl::Span dims_to_insert) { + CHECK(absl::c_is_sorted(dims_to_insert)); const Shape& operand_shape = 
operand->shape(); int64 output_shape_rank = @@ -318,25 +318,25 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, *padding_config.add_dimensions() = padding_config_dim; HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique( + HloInstruction::CreateConstant(absl::make_unique( LiteralUtil::Zero(operand->shape().element_type())))); return MakePadHlo(operand, zero, padding_config); } StatusOr BroadcastZeros( HloComputation* computation, PrimitiveType element_type, - ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + absl::make_unique(LiteralUtil::Zero(element_type)))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } StatusOr> CreateComputationWithSignature( - ArraySlice domain, const Shape& range, - tensorflow::StringPiece name) { - HloComputation::Builder b{std::string(name)}; + absl::Span domain, const Shape& range, + absl::string_view name) { + HloComputation::Builder b{string(name)}; int64 param_idx = 0; for (const Shape* param_shape : domain) { b.AddInstruction(HloInstruction::CreateParameter( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 5ff8946fb098b57ae563a8ade47e8323f807a369..a1c4b374d1121bbf94f5940b52859682808119c4 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -40,10 +40,10 @@ StatusOr MakePadHlo(HloInstruction* operand, // Creates a slice HLO instruction and adds it to the computation containing // `operand`. 
-StatusOr MakeSliceHlo( - HloInstruction* operand, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); +StatusOr MakeSliceHlo(HloInstruction* operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); // Creates a convolution HLO instruction and adds it to the computation // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). @@ -53,8 +53,8 @@ StatusOr MakeConvolveHlo( // Creates a transpose HLO instruction and adds it to the computation containing // `operand`. -StatusOr MakeTransposeHlo( - HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); +StatusOr MakeTransposeHlo(HloInstruction* operand, + absl::Span dimensions); // Creates a reshape HLO instruction and adds it to the computation containing // `operand`. @@ -62,15 +62,14 @@ StatusOr MakeReshapeHlo(const Shape& result_shape, HloInstruction* operand); StatusOr MakeReshapeHlo( - tensorflow::gtl::ArraySlice result_shape_dim_bounds, - HloInstruction* operand); + absl::Span result_shape_dim_bounds, HloInstruction* operand); // Creates a dynamic-slice HLO instruction and adds it to the computation // containing `operand` and `start_indices` (`operand` and `start_indices` must // be in the same computation). StatusOr MakeDynamicSliceHlo( HloInstruction* operand, HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Creates a dynamic-update-slice HLO instruction and adds it to the computation // containing `operand`, `update` and `start_indices` (`operand`, `update` and @@ -82,9 +81,8 @@ StatusOr MakeDynamicUpdateSliceHlo( // Creates a broadcast HLO instruction and adds it to the computation containing // `operand`. 
StatusOr MakeBroadcastHlo( - HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimensions, - tensorflow::gtl::ArraySlice result_shape_bounds); + HloInstruction* operand, absl::Span broadcast_dimensions, + absl::Span result_shape_bounds); // Creates a GetTupleElement HLO instruction and adds it to the computation // containing `operand`. @@ -95,7 +93,7 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, // containing `operands` (`operands` must be non-empty and every element must be // contained in the same computation). StatusOr MakeConcatHlo( - tensorflow::gtl::ArraySlice operands, int64 dimension); + absl::Span operands, int64 dimension); // Creates a Dot HLO instruction and adds it to the computation containing `lhs` // and `rhs` (both must be in the same computation). @@ -104,9 +102,8 @@ StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, // Creates a Map HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. -StatusOr MakeMapHlo( - tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation); +StatusOr MakeMapHlo(absl::Span operands, + HloComputation* map_computation); // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of @@ -138,7 +135,7 @@ StatusOr PrependDegenerateDims(HloInstruction* operand, // For instance if `operand` has shape f32[200,9,7] and expanded_dims is // {2,5,20} the result is `operand` reshaped to [2,5,20,9,7]. StatusOr ExpandFirstDimIntoNDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice expanded_dims); + HloInstruction* operand, absl::Span expanded_dims); // Elides (via reshape) a set of degenerate dimensions (dimensions containing // exactly one element), `dims_to_elide` from `operand`. 
Every dimension in @@ -148,7 +145,7 @@ StatusOr ExpandFirstDimIntoNDims( // For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide // is {1,5} then the result is `operand` reshaped to [19,20,1,7,9]. StatusOr ElideDegenerateDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_elide); + HloInstruction* operand, absl::Span dims_to_elide); // Inserts (via reshape) a set of degenerate dimensions (dimensions containing // exactly one element), `dims_to_insert` into `operand`. The dimensions in @@ -158,7 +155,7 @@ StatusOr ElideDegenerateDims( // For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is // {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34]. StatusOr InsertDegenerateDims( - HloInstruction* operand, tensorflow::gtl::ArraySlice dims_to_insert); + HloInstruction* operand, absl::Span dims_to_insert); // Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the // front and `zeros_to_append` zeros in the back. @@ -171,13 +168,13 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, // broadcast instruction is emitted into `computation`. StatusOr BroadcastZeros( HloComputation* computation, PrimitiveType element_type, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Creates a HLO computation that takes arguments of type `domain` and produces // a value of type `range`. 
StatusOr> CreateComputationWithSignature( - tensorflow::gtl::ArraySlice domain, const Shape& range, - tensorflow::StringPiece name); + absl::Span domain, const Shape& range, + absl::string_view name); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 60d3e71757d5ce31e025c744e089ff56091d9a43..eb6affadc800d9d5cf7b143386b46f3e8c608e63 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -14,23 +14,22 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace { -using tensorflow::gtl::ArraySlice; -class HloCreationUtilsTest : public HloTestBase { +class HloCreationUtilsTest : public HloVerifiedTestBase { protected: - static std::unique_ptr CreateModuleWithProgramShape( - PrimitiveType primitive_type, ArraySlice input_shape_dims, - ArraySlice output_shape_dims, HloInstruction** param, + HloModule* CreateModuleWithProgramShape( + PrimitiveType primitive_type, absl::Span input_shape_dims, + absl::Span output_shape_dims, HloInstruction** param, HloComputation** entry_computation) { Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); Shape output_shape = @@ -48,10 +47,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { HloInstruction* param; HloComputation* 
entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{2}, &param, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{2}, + &param, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed, CollapseFirstNDims(param, 1)); @@ -68,7 +67,7 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( + HloModule* module = CreateModuleWithProgramShape( S32, /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, &param, &entry_computation); @@ -93,10 +92,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 2}, &param, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{1, 2}, + &param, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended, PrependDegenerateDims(param, 1)); @@ -114,7 +113,7 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( + HloModule* module = CreateModuleWithProgramShape( S32, /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 1, 2}, &param, &entry_computation); @@ -135,10 +134,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{}, /*output_shape_dims=*/{1, 1}, &param, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{}, + /*output_shape_dims=*/{1, 1}, +
&param, &entry_computation); TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, PrependDegenerateDims(param, 2)); @@ -155,7 +154,7 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( + HloModule* module = CreateModuleWithProgramShape( S32, /*input_shape_dims=*/{6}, /*output_shape_dims=*/{3, 1, 2}, &param, &entry_computation); @@ -177,10 +176,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{2}, /*output_shape_dims=*/{6}, &param, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{2}, + /*output_shape_dims=*/{6}, + &param, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zero_padded_param, @@ -198,10 +197,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - S32, - /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, &param, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(S32, + /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + &param, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, @@ -219,10 +218,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { HloInstruction* param; HloComputation* entry_computation; - std::unique_ptr module = CreateModuleWithProgramShape( - F32, - /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, &param, - &entry_computation); + HloModule* module = CreateModuleWithProgramShape(F32, + /*input_shape_dims=*/{}, + /*output_shape_dims=*/{2, 2}, + &param, &entry_computation); TF_ASSERT_OK_AND_ASSIGN( HloInstruction * zeros, diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index
06484f4012fc091f70df7bc8ec231ce3fcf89669..cb367adf5ef29111838dd6ee1b770394eef1301c 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -103,6 +104,9 @@ int64 CseHash(const HloInstruction* instruction) { for (auto operand : instruction->operands()) { hash = tensorflow::Hash64Combine(hash, operand->unique_id()); } + if (instruction->opcode() == HloOpcode::kConstant) { + hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash()); + } return hash; } diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h index 5e2b348bdda2b31556fb692e24d2bad2e4173ef5..a28c03599a8765da708f37b986010713654647cb 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.h +++ b/tensorflow/compiler/xla/service/hlo_cse.h @@ -34,7 +34,7 @@ class HloCSE : public HloPassInterface { : is_layout_sensitive_(is_layout_sensitive), only_fusion_computations_(only_fusion_computations) {} ~HloCSE() override = default; - tensorflow::StringPiece name() const override { return "cse"; } + absl::string_view name() const override { return "cse"; } // Run CSE on the given module. Returns whether the module was changed (common // subexpressions were found and eliminated). diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 90fbaa37c5a70a78a9a818b4a8968f3406c671b1..406d712ec6783a310aabc6600b8b70e1a1ae30a9 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -20,9 +20,9 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index bbfb0c253f583b633c4b2c34b2f068b563d3d9e0..6a63681996bc57f4ef16b2405ffc8ce4f003e783 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -29,8 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -46,8 +46,7 @@ namespace { // // In this case, we should be able to reuse p0 and output, although p0 has // multiple uses. 
-bool MultiDynamicSliceUseShareSameIndices( - tensorflow::gtl::ArraySlice uses) { +bool MultiDynamicSliceUseShareSameIndices(absl::Span uses) { if (uses.empty()) { return false; } @@ -78,8 +77,8 @@ bool MultiDynamicSliceUseShareSameIndices( } // namespace -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; HloDataflowAnalysis::HloDataflowAnalysis( const HloModule& module, bool ssa_form, bool bitcast_defines_value, @@ -93,7 +92,7 @@ HloDataflowAnalysis::HloDataflowAnalysis( bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( const HloInstruction* inst) { tensorflow::gtl::FlatSet visited; - tensorflow::gtl::InlinedVector stack; + absl::InlinedVector stack; stack.push_back(inst); while (!stack.empty()) { const HloInstruction* current = stack.back(); @@ -221,7 +220,7 @@ string HloDataflowAnalysis::ToString() const { bool HloDataflowAnalysis::Phi( HloInstruction* instruction, - tensorflow::gtl::ArraySlice inputs) { + absl::Span inputs) { CHECK(ssa_form_); VLOG(4) << "Phi(" << instruction->name() << ")"; VLOG(5) << "instruction value set = " @@ -837,7 +836,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { return Unimplemented( "Computation %s is called in both a parallel (eg, kMap) and " "sequential (eg, kCall) context", - computation->name().c_str()); + computation->name()); } if (call_graph_node.caller_callsites().empty() || call_graph_node.context() == CallContext::kParallel) { @@ -886,7 +885,7 @@ StatusOr> HloDataflowAnalysis::Run( VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis( + auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis( module, ssa_form, bitcast_defines_value, fusion_can_share_buffer)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); @@ -976,28 +975,22 @@ Status HloDataflowAnalysis::Verify() const { 
bool HloDataflowAnalysis::DoesNotUseOperandBuffer( const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user) const { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - // Iterate through all users of all uses of the fusion parameter value. - // Return false if any uses are detected, returns true otherwise. - const HloValue& value = GetValueDefinedAt(fusion_param, index); - return value.uses().empty(); - } else { - // Return false if no value at 'operand' and 'index' is used at 'user'. - for (const HloValue* value : GetValueSet(operand, index).values()) { - for (const HloUse& use : value->uses()) { - if (use.instruction == user) { - return false; + // Return false if no value at 'operand' and 'index' is used at 'user'. + for (const HloValue* value : GetValueSet(operand, index).values()) { + for (const HloUse& use : value->uses()) { + if (use.instruction == user) { + if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + HloInstruction* fusion_param = + user->fused_parameter(use.operand_number); + const HloValue& value = + GetValueDefinedAt(fusion_param, use.operand_index); + return value.uses().empty(); } + return false; } } } - return true; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index f4abc7a7c7dcfb223067fe946bec0c5ef32f206b..e62c1c2ac81981e1f44f4c7e1479107979576e32 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -25,6 +25,7 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -138,7 +138,8 @@ class HloDataflowAnalysis { // Returns true if 'user' cannot possibly use the buffer at 'index' in // 'operand'. Returns false otherwise. // - // REQUIRES: 'operand' is an operand of 'user'. + // 'operand' does not have to be an operand of 'user'. This can be the case + // with indirect uses. bool DoesNotUseOperandBuffer(const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user) const; @@ -201,7 +202,7 @@ class HloDataflowAnalysis { // the given instruction. If skip_top_level is true, then the top level of the // value set of 'instruction' is not modified. bool Phi(HloInstruction* instruction, - tensorflow::gtl::ArraySlice inputs); + absl::Span inputs); // Updates the positions of the HloValues in the output of the given // instruction. 
This should be called after the instruction value set of diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 4755c4a0cf8d268b1c47e596a14605eb2c60b36c..d1a96c10f88e3c05e21a6db4eccb46683cd64c4a 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1963,6 +1963,54 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); } +// Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the +// parameter tuple. +TEST_F(DoesNotUseOperandBufferTest, IndirectUses) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto t0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0)); + auto t1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1)); + // Swap the tuple elements. + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0})); + + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. 
+ auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); + // The same holds for the parameter tuple, except that the tuple elements are + // swapped in 'tuple'. + EXPECT_TRUE( + dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion)); + EXPECT_FALSE( + dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion)); +} + class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {}; TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index 4e244494d6f98c48f4376bd762f116b9a9c2084d..1fe69b1395753a612499e6e87bfc22f8ac8e767b 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ b/tensorflow/compiler/xla/service/hlo_dce.h @@ -36,7 +36,7 @@ namespace xla { class HloDCE : public HloPassInterface { public: ~HloDCE() override {} - tensorflow::StringPiece name() const override { return "dce"; } + absl::string_view name() const override { return "dce"; } // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). 
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 26e3736e01270dbc6ca67647e814843aba2d1e3d..3b5cde2996c4195ef458662cd21de85a832d8d55 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc index 78955db0da02f16eb93689db947dc1190ab7049a..72185698c9bdcbf2bebed7ee82bc4ed082ce6a14 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -31,31 +31,10 @@ class HloDomainIsolator::RunContext { StatusOr Run(); private: - // Inserts a kDomain instruction between parent and operand, in case - // the attribute (ie, sharding) values change between instruction and operand. - // Returns the newly inserted kDomain instruction, or nullptr if no kDomain - // instruction was necessary. 
- StatusOr CreateDomain(HloInstruction* instruction, - HloInstruction* parent, - HloInstruction* operand); - HloModule* module_; HloDomainIsolator* isolator_; }; -StatusOr HloDomainIsolator::RunContext::CreateDomain( - HloInstruction* instruction, HloInstruction* parent, - HloInstruction* operand) { - HloInstruction* domain = nullptr; - std::unique_ptr domain_instruction = - isolator_->creator_(instruction, operand); - if (domain_instruction != nullptr) { - domain = operand->parent()->AddInstruction(std::move(domain_instruction)); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain)); - } - return domain; -} - StatusOr HloDomainIsolator::RunContext::Run() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); @@ -71,16 +50,16 @@ StatusOr HloDomainIsolator::RunContext::Run() { // When applying multiple domains, we could end up stacking more than // one in one edge, so here we want to build the effective // (kDomain-less) instruction->operand edge. - HloInstruction* parent = instruction; - while (operand->opcode() == HloOpcode::kDomain) { - parent = operand; - operand = operand->mutable_operand(0); + HloInstruction* root = operand; + while (root->opcode() == HloOpcode::kDomain) { + root = root->mutable_operand(0); } // Check whether a kDomain is necessary between instruction and operand. 
- TF_ASSIGN_OR_RETURN(HloInstruction * domain, - CreateDomain(instruction, parent, operand)); + HloInstruction* domain = + isolator_->creator_(instruction, root, operand); if (domain != nullptr) { VLOG(4) << "New domain: " << domain->ToString(); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); ++added_domains; } } diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index eded3e78eead76c4564daee119034c5031eba409..d36631fc2f16902ed8f1f89f903027081f9b3801 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -34,14 +34,16 @@ class HloDomainIsolator : public HloPassInterface { public: // Creates a new kDomain instruction for the edge between the use instruction // (the first HloInstruction argument), and the operand instruction (the - // second HloInstruction argument). + // third HloInstruction argument) if the interesting attribute of the + // instruction differs from the attribute of the root (the second + // HloInstruction argument). // Returns nullptr in case no domain separation is necessary. - using DomainCreator = std::function( - HloInstruction*, HloInstruction*)>; + using DomainCreator = std::function; explicit HloDomainIsolator(DomainCreator creator); - tensorflow::StringPiece name() const override { return "domain_isolator"; } + absl::string_view name() const override { return "domain_isolator"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 9e096320db5048457435199627a1ef1fe1572177..8b2846e0c277b3e7cffd578d988d0a09c13833ed 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -17,6 +17,7 @@ limitations under the License.
#include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" @@ -25,14 +26,14 @@ namespace xla { /* static */ StatusOr> HloDomainMap::Create( HloComputation* computation, string domain_kind) { - auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); TF_RETURN_IF_ERROR(domain_map->Populate(computation)); return std::move(domain_map); } /* static */ StatusOr> HloDomainMap::Create( HloModule* module, string domain_kind) { - auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); for (HloComputation* computation : module->computations()) { TF_RETURN_IF_ERROR(domain_map->Populate(computation)); } @@ -56,14 +57,14 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { // both sides. 
for (HloInstruction* operand : instruction->unique_operands()) { if (IsDomainInstruction(operand)) { - auto domain = MakeUnique(); + auto domain = absl::make_unique(); domain->enter_domains.insert(operand); domain->exit_domains.insert(instruction); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } } if (instruction == instruction->parent()->root_instruction()) { - auto domain = MakeUnique(); + auto domain = absl::make_unique(); domain->enter_domains.insert(instruction); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } @@ -71,6 +72,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { } Status HloDomainMap::Populate(HloComputation* computation) { + InstructionOrderMap instructions_post_order; + int64 count = 0; + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { + instructions_post_order.insert(std::make_pair(instruction, count++)); + } for (HloInstruction* instruction : computation->instructions()) { if (IsDomainInstruction(instruction)) { // If this is a kDomain of the kind we are currently processing, check @@ -84,7 +90,7 @@ Status HloDomainMap::Populate(HloComputation* computation) { continue; } TF_ASSIGN_OR_RETURN(std::unique_ptr domain, - CreateDomain(instruction)); + CreateDomain(instruction, instructions_post_order)); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } return Status::OK(); @@ -142,10 +148,12 @@ Status HloDomainMap::ExpandDomain(HloInstruction* instruction, } StatusOr> HloDomainMap::CreateDomain( - HloInstruction* instruction) const { - auto domain = MakeUnique(); + HloInstruction* instruction, + const InstructionOrderMap& instructions_order) const { + auto domain = absl::make_unique(); TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); - domain->instructions = MakeNonDomainInstructions(domain->reach_set); + domain->instructions = + MakeNonDomainInstructions(domain->reach_set, instructions_order); return std::move(domain); } @@ -167,7 +175,8 @@ bool 
HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { /* static */ std::vector HloDomainMap::MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet& instruction_set) { + const tensorflow::gtl::FlatSet& instruction_set, + const InstructionOrderMap& instructions_order) { std::vector instructions; instructions.reserve(instruction_set.size()); for (HloInstruction* instruction : instruction_set) { @@ -175,9 +184,10 @@ HloDomainMap::MakeNonDomainInstructions( instructions.push_back(instruction); } } + // sort instructions according to instructions_order std::sort(instructions.begin(), instructions.end(), - [](HloInstruction* a, HloInstruction* b) { - return a->unique_id() < b->unique_id(); + [&instructions_order](HloInstruction* a, HloInstruction* b) { + return instructions_order.at(a) < instructions_order.at(b); }); return instructions; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 1ca71597253eecfb45ae8f384240033a57045277..633109249a91eec3d7b4cbe5b423b73f980217c9 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -70,6 +70,11 @@ class HloDomainMap { int64 GetDomainId(HloInstruction* instruction) const; private: + // Map used for representing instruction ordering, i.e. + // order_map[a] < order_map[b] means a must be ordered before b. + using InstructionOrderMap = + tensorflow::gtl::FlatMap; + HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} // Check if the kDomain instruction is facing (via its operand link) another @@ -95,12 +100,14 @@ class HloDomainMap { // Creates a domain data structure using the ExpandDomain() API. StatusOr> CreateDomain( - HloInstruction* instruction) const; + HloInstruction* instruction, + const InstructionOrderMap& instructions_order) const; // Out of an instruction set, returns a vector of all the ones which are not // a kDomain kind. 
static std::vector MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet& instruction_set); + const tensorflow::gtl::FlatSet& instruction_set, + const InstructionOrderMap& instructions_order); string domain_kind_; std::vector> instruction_domains_; diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index f855f2a1fc944fcc11c9afed278bef4af87813da..6c142ee47421049e8a25dfb80a6297e02fe782f1 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -20,10 +20,10 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -44,7 +44,10 @@ class DomainMetadata { // two domains of different kind intersect each other. tensorflow::gtl::FlatSet reach_set; - // The same instructions in reach_set, but purged from kDomain instructions. + // The same instructions in reach_set, but purged from kDomain instructions + // and ordered according to their computation graph post-order, i.e. + // if instructions[pos_a] depends on instructions[pos_b], then pos_a > + // pos_b. std::vector instructions; // If we consider a graph edge as an arrow oriented from the operand to the @@ -63,7 +66,7 @@ class DomainMetadata { // Returns the metadata type. A unique identifier which describes the real // metadata type. - virtual tensorflow::StringPiece Kind() const = 0; + virtual absl::string_view Kind() const = 0; // Compares the metadata object with another one and returns true if the // two matches. 
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h index c859e05f02e54d601804b641094ecdd11bbe1aed..97bc8ef604092acc849b55b09af8a24bf775529e 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.h +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -35,13 +35,13 @@ class HloDomainRemover : public HloPassInterface { // instructions in it with the same attributes (ie, sharding), a normalizer // function is tasked at applying attribute normalization on the instructions // within such domain. - HloDomainRemover(tensorflow::StringPiece kind, + HloDomainRemover(absl::string_view kind, std::function normalizer) - : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} + : kind_(kind), normalizer_(std::move(normalizer)) {} - tensorflow::StringPiece name() const override { return "domain_remover"; } + absl::string_view name() const override { return "domain_remover"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 70271be304336767bd3fd01297217e9309a941b6..974ab94467dfb63325698b4590dac1abd1ed9f89 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" @@ -45,9 +46,8 @@ class HloDomainTest : public HloVerifiedTestBase { // Checks whether there is a kDomain instruction in the edge between the // instruction and the operand. 
- bool HasDomainEdge(HloModule* module, - tensorflow::StringPiece instruction_name, - tensorflow::StringPiece operand_name) { + bool HasDomainEdge(HloModule* module, absl::string_view instruction_name, + absl::string_view operand_name) { HloInstruction* instruction = FindInstruction(module, instruction_name); HloInstruction* operand = FindInstruction(module, operand_name); CHECK_NE(instruction, nullptr); @@ -65,7 +65,7 @@ class HloDomainTest : public HloVerifiedTestBase { return false; } - StatusOr ParseModule(tensorflow::StringPiece hlo_string) { + StatusOr ParseModule(absl::string_view hlo_string) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); ParseAndVerifyModule(hlo_string, config); @@ -80,10 +80,10 @@ class OpNameMetadata : public DomainMetadata { explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {} std::unique_ptr Clone() const override { - return MakeUnique(opname_); + return absl::make_unique(opname_); } - tensorflow::StringPiece Kind() const override { return KindName(); } + absl::string_view Kind() const override { return KindName(); } bool Matches(const DomainMetadata& other) const override { const OpNameMetadata* other_ptr = @@ -97,25 +97,26 @@ class OpNameMetadata : public DomainMetadata { string ToString() const override { return opname_; } - static tensorflow::StringPiece KindName() { return "opname"; } + static absl::string_view KindName() { return "opname"; } private: string opname_; }; // Creator function for OpNameMetadata domains. 
-std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, - HloInstruction* operand) { - if (instruction->metadata().op_name() == operand->metadata().op_name()) { +HloInstruction* OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* root, + HloInstruction* operand) { + if (instruction->metadata().op_name() == root->metadata().op_name()) { return nullptr; } std::unique_ptr operand_side_metadata = - MakeUnique(operand->metadata().op_name()); + absl::make_unique(root->metadata().op_name()); std::unique_ptr user_side_metadata = - MakeUnique(instruction->metadata().op_name()); - return HloInstruction::CreateDomain(operand->shape(), operand, - std::move(operand_side_metadata), - std::move(user_side_metadata)); + absl::make_unique(instruction->metadata().op_name()); + return operand->parent()->AddInstruction(HloInstruction::CreateDomain( + operand->shape(), operand, std::move(operand_side_metadata), + std::move(user_side_metadata))); } Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain, @@ -142,7 +143,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -184,7 +185,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(!isolator_changed); } @@ -211,7 +212,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + 
HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -248,7 +249,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_FALSE(isolator_changed); } @@ -302,7 +303,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator sharding_isolator(CreateShardingDomain); + HloDomainIsolator sharding_isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, sharding_isolator.Run(module)); EXPECT_TRUE(sharding_isolator_changed); @@ -344,7 +345,8 @@ ENTRY entry { token = token[] after-all() infeed = ((f32[4], f32[4]), token[]) infeed(token), sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} - infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0 + infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0, + sharding={{maximal device=1}, {maximal device=0}} gte0 = f32[4] get-tuple-element(infeed.data), index=0 gte1 = f32[4] get-tuple-element(infeed.data), index=1 copy0 = f32[4] copy(gte0) @@ -356,7 +358,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -378,11 +380,8 @@ ENTRY entry { // \ / // TUPLE // | - HloInstruction* infeed = FindInstruction(module, "infeed"); - ASSERT_NE(infeed, nullptr); - HloInstruction* infeed_data = - 
infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + HloInstruction* infeed_data = FindInstruction(module, "infeed.data"); + ASSERT_NE(infeed_data, nullptr); auto infeed_data_users = infeed_data->users(); HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction( @@ -445,7 +444,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -474,8 +473,8 @@ ENTRY entry { TEST_F(HloDomainTest, DumpParseNullSharding) { auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {}); - auto sharding_md_0 = MakeUnique(nullptr); - auto sharding_md_1 = MakeUnique(nullptr); + auto sharding_md_0 = absl::make_unique(nullptr); + auto sharding_md_1 = absl::make_unique(nullptr); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain( @@ -490,6 +489,7 @@ TEST_F(HloDomainTest, DumpParseNullSharding) { ASSERT_TRUE(ParseModule(hlo_string).status().ok()); } +// Tuple inputs are domain instructions. 
TEST_F(HloDomainTest, DomainTuple) { const char* const hlo_string = R"( HloModule Module @@ -497,14 +497,15 @@ HloModule Module ENTRY entry { p0 = f32[4] parameter(0), sharding={maximal device=0} cst = u32[] constant(0), sharding={maximal device=1} - tpl = (u32[], f32[4]) tuple(cst, p0), sharding={{maximal device=1}, {maximal device=0}} + tpl = (u32[], f32[4]) tuple(cst, p0), + sharding={{maximal device=1}, {maximal device=0}} ROOT gte = f32[4] get-tuple-element(tpl), index=1, sharding={maximal device=0} } )"; TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -523,5 +524,168 @@ ENTRY entry { tpl->sharding()); } +TEST_F(HloDomainTest, MultiDomainMultiUser) { + const char* const hlo_string = R"( + HloModule Module + +ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) { + %p0 = (f32[4], f32[4]) parameter(0) + %a = f32[4]{0} get-tuple-element(%p0), index=0 + %domain = f32[4] domain(%a), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %b = f32[4] get-tuple-element(%p0), index=1 + %domain.1 = f32[4] domain(%b), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %c = f32[4] add(%domain, %domain.1), sharding={maximal device=1} + %domain.2 = f32[4] domain(%c), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %d = f32[4] subtract(%domain, %c), + sharding={maximal device=1}, metadata={op_name="D"} + %domain.3 = f32[4] domain(%d), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %e = f32[4] multiply(%c, %d), + sharding={maximal device=1}, metadata={op_name="D"} + %f = f32[4] add(f32[4]{0} %e, f32[4]{0} %c), sharding={maximal device=1} + %domain.4 = f32[4]{0} domain(%f), + domain={kind="sharding", entry={maximal 
device=0}, exit={maximal device=1}} + ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4) +})"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator opname_isolator(OpNameDomainCreator); + TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, + opname_isolator.Run(module)); + EXPECT_TRUE(opname_isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + + HloDomainRemover sharding_remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, + sharding_remover.Run(module)); + EXPECT_TRUE(sharding_remover_changed); + + HloDomainRemover opname_remover(OpNameMetadata::KindName(), + OpNameDomainNormalizer); + TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, + opname_remover.Run(module)); + EXPECT_TRUE(opname_remover_changed); + + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "c")); +} + +// Emulate instructions inserted at top and bottom within nested tuple domain. 
+TEST_F(HloDomainTest, DomainTupleTopBottomInsert) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = f32[4] parameter(0), sharding={maximal device=1} + p1 = (f32[5], f32[6]) parameter(1), + sharding={{maximal device=1}, {maximal device=0}} + tuple.0 = (f32[4], (f32[5], f32[6])) tuple(p0, p1), + sharding={{maximal device=1}, {maximal device=1}, {maximal device=0}} + ROOT res = (f32[5], f32[6]) get-tuple-element(tuple.0), index=1, + sharding={{maximal device=1}, {maximal device=0}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + + HloDomainIsolator isolator(ShardingDomainCreator{}); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + // Clear sharding of tuple.0 instruction, in order to test domain sharding + // application. + auto tuple0 = FindInstruction(module, "tuple.0"); + tuple0->clear_sharding(); + + // Insert the following instructons above and below tuple.0, to emulate other + // passes effects: + // COPY.0 + // \ / + // TUPLE.0 + // / \ + // COPY.1 \ + // / \ + // GTE.0 GTE.1 + // | | + // | COPY.2 + // \ / + // \ / + // TUPLE.1 + // | + auto tuple0_users = tuple0->users(); + auto computation = tuple0->parent(); + HloInstruction* copy0 = computation->AddInstruction( + HloInstruction::CreateUnary(tuple0->operand(1)->shape(), HloOpcode::kCopy, + tuple0->mutable_operand(1))); + TF_EXPECT_OK(tuple0->ReplaceOperandWith(1, copy0)); + + HloInstruction* copy1 = computation->AddInstruction( + HloInstruction::CreateUnary(tuple0->shape(), HloOpcode::kCopy, tuple0)); + HloInstruction* gte0 = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(copy1->shape(), 0), copy1, 0)); + HloInstruction* gte1 = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(tuple0->shape(), 1), tuple0, 1)); + HloInstruction* copy2 = computation->AddInstruction( + 
HloInstruction::CreateUnary(gte1->shape(), HloOpcode::kCopy, gte1)); + HloInstruction* tuple1 = + computation->AddInstruction(HloInstruction::CreateTuple({gte0, copy2})); + + for (HloInstruction* user : tuple0_users) { + TF_EXPECT_OK(tuple0->ReplaceUseWith(user, tuple1)); + } + + HloDomainRemover remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + EXPECT_TRUE(tuple0->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(tuple0->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + tuple0->sharding()); + + EXPECT_TRUE(copy0->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(copy0->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + copy0->sharding()); + + // copy1 has partial information only from gte.0, so in the end it gets no + // sharding at all. During propagation it does propagate the information from + // gte.0 though, enabling Tuple.0 to be fully sharded. 
+ EXPECT_FALSE(copy1->has_sharding()); + + EXPECT_TRUE(gte0->has_sharding()); + EXPECT_EQ(HloSharding::AssignDevice(1), gte0->sharding()); + + EXPECT_TRUE(gte1->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(gte1->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + gte1->sharding()); + + EXPECT_TRUE(copy2->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(copy2->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + copy2->sharding()); + + EXPECT_TRUE(tuple1->has_sharding()); + EXPECT_EQ(tuple0->sharding(), tuple1->sharding()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc index 751fc677e2d955fd3d9f8970f7c0370a22c054bf..dc514ae3e5c6907f6398805d171e69ee8635d08e 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc @@ -52,7 +52,7 @@ Status HloDomainVerifier::RunContext::PopulateDomainKinds() { TF_RET_CHECK(instruction->user_side_metadata().Kind() == instruction->operand_side_metadata().Kind()) << instruction->ToString(); - kinds.insert(instruction->user_side_metadata().Kind().ToString()); + kinds.insert(string(instruction->user_side_metadata().Kind())); } } } diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h index 8e53cf97f8ba9a88140a909ad20c1a938aec8c1f..81d6d69a8c59da2fc77cb2bab808602cd964fdaf 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h @@ -33,7 +33,7 @@ class HloDomainVerifier : public HloPassInterface { public: HloDomainVerifier(std::vector kinds) : kinds_(std::move(kinds)) {} - tensorflow::StringPiece name() const override { return "domain_verifier"; } + absl::string_view name() const override { return "domain_verifier"; } StatusOr Run(HloModule* module) override; diff --git 
a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index b9244b8e9e5f34e7ac5113c8eacb6f8243eea314..72006e17e7e7ec09b62e88d05b695ec9f4c49647 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -151,7 +151,11 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { } TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); - if (!HasOperandType(hlo, eliminate_type_)) { + bool nullary = hlo->operands().empty(); + bool wrong_element_type = hlo->shape().element_type() == eliminate_type_; + bool should_eliminate_type = (nullary && wrong_element_type) || + HasOperandType(hlo, eliminate_type_); + if (!should_eliminate_type) { // If this CHECK fires, then this was an instruction that does not take // the elimination type as an operand but it does return it. This pass // does not have a feature to change the output type in that case, so diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h index 2b109225d0b192e5c9e4f6d841377ffad8078dc2..44ded2c2faf7c38d1e2f2aae577ddc07089bbb6a 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -32,9 +32,7 @@ class HloElementTypeConverter : public HloPassInterface { HloElementTypeConverter(PrimitiveType eliminate_type, PrimitiveType replace_with_type); - tensorflow::StringPiece name() const override { - return "element_type_converter"; - } + absl::string_view name() const override { return "element_type_converter"; } // Returns the pass on the module and returns whether the module was modified. 
StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 51353eea6e72d5a131897f3c3ae312046051103e..441dcad00047311d682c0623964ee63aab341904 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -23,13 +23,15 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -43,7 +45,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -52,7 +53,6 @@ namespace xla { namespace { -using tensorflow::gtl::ArraySlice; template StatusOr> Compare(const Shape& shape, HloOpcode opcode, @@ -95,11 +95,12 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = MakeUnique(shape); - TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + auto result = absl::make_unique(shape); + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } @@ -125,11 +126,12 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = MakeUnique(shape); - TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { - return compare_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); + auto result = absl::make_unique(shape); + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); return std::move(result); } @@ -138,49 +140,62 @@ StatusOr> Compare( HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { - typed_visitors_[PRED] = MakeUnique>(this); - typed_visitors_[U8] = MakeUnique>(this); - typed_visitors_[U16] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "U16."); - }); - typed_visitors_[U32] = MakeUnique>(this); - typed_visitors_[U64] = MakeUnique>(this); - 
typed_visitors_[S8] = MakeUnique>(this); - typed_visitors_[S16] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "S16."); - }); - typed_visitors_[S32] = MakeUnique>(this); - typed_visitors_[S64] = MakeUnique>(this); + typed_visitors_[PRED] = + absl::make_unique>(this); + typed_visitors_[U8] = + absl::make_unique>(this); + typed_visitors_[U16] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "U16."); + }); + typed_visitors_[U32] = + absl::make_unique>(this); + typed_visitors_[U64] = + absl::make_unique>(this); + typed_visitors_[S8] = absl::make_unique>(this); + typed_visitors_[S16] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "S16."); + }); + typed_visitors_[S32] = + absl::make_unique>(this); + typed_visitors_[S64] = + absl::make_unique>(this); typed_visitors_[F16] = - MakeUnique>(this); - typed_visitors_[F32] = MakeUnique>(this); - typed_visitors_[F64] = MakeUnique>(this); - typed_visitors_[C64] = MakeUnique>(this); + absl::make_unique>(this); + typed_visitors_[F32] = + absl::make_unique>(this); + typed_visitors_[F64] = + absl::make_unique>(this); + typed_visitors_[C64] = + absl::make_unique>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, std::tanh). To make evaluator work with BF16, we set all // elementwise computations to be done in F32 and do BF16<->F32 conversion // around the input and the output of the computations. 
typed_visitors_[BF16] = - MakeUnique>(this); - - typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); - }); - typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { - return Unimplemented( - "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); - }); + absl::make_unique>(this); + + typed_visitors_[TUPLE] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); + }); + typed_visitors_[OPAQUE] = + absl::make_unique([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); + }); } template StatusOr> HloEvaluator::Evaluate( - const HloModule& module, ArraySlice arg_literals) { + const HloModule& module, absl::Span arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); evaluated_.clear(); @@ -197,7 +212,8 @@ StatusOr> HloEvaluator::Evaluate( template StatusOr> HloEvaluator::Evaluate( - const HloComputation& computation, ArraySlice arg_literals) { + const HloComputation& computation, + absl::Span arg_literals) { CHECK(computation.parent() != nullptr); XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); @@ -214,9 +230,8 @@ StatusOr> HloEvaluator::Evaluate( template StatusOr> HloEvaluator::Evaluate( - HloInstruction* instruction, ArraySlice arg_literals) { + HloInstruction* instruction, absl::Span arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); evaluated_.clear(); arg_literals_.clear(); @@ -253,7 +268,6 @@ StatusOr> HloEvaluator::Evaluate( return tensorflow::errors::FailedPrecondition( "Not all operands are constants."); } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); arg_literals_.clear(); evaluated_.clear(); @@ -378,7 +392,7 @@ 
Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { } Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { - ArraySlice operands(concatenate->operands()); + absl::Span operands(concatenate->operands()); // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); @@ -423,7 +437,7 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { if (!ShapeUtil::ElementIsFloating(operand->shape())) { return InvalidArgument( "expected element type in shape to be float for IsFinite op, got: %s", - PrimitiveType_Name(operand->shape().element_type()).c_str()); + PrimitiveType_Name(operand->shape().element_type())); } switch (operand->shape().element_type()) { @@ -464,9 +478,9 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s", - ShapeUtil::HumanString(compare->shape()).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); + ShapeUtil::HumanString(compare->shape()), + ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape())); } TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type()); @@ -555,43 +569,41 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -// Returns an ShapeUtil::IndexIterationSpace that iterates over the output -// gather dimensions while keeping the rest of the output dimensions clamped to -// 0. -ShapeUtil::IndexIterationSpace IterationSpaceForOutputGatherIndices( +// Returns an ShapeUtil::IndexIterationSpace that iterates over the output batch +// dimensions while keeping the rest of the output dimensions clamped to 0. 
+ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices( const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) { int64 output_rank = output_shape.dimensions_size(); std::vector index_base(output_rank, 0); std::vector index_count; index_count.reserve(output_rank); for (int64 i = 0; i < output_rank; i++) { - bool is_output_gather_dim = - !c_binary_search(dim_numbers.output_window_dims(), i); - index_count.push_back(is_output_gather_dim ? output_shape.dimensions(i) - : 1); + bool is_output_batch_dim = + !absl::c_binary_search(dim_numbers.offset_dims(), i); + index_count.push_back(is_output_batch_dim ? output_shape.dimensions(i) : 1); } return {std::move(index_base), std::move(index_count), std::vector(output_rank, 1)}; } -// Return an ShapeUtil::IndexIterationSpace that iterates over the output window +// Return an ShapeUtil::IndexIterationSpace that iterates over the output slice // dimensions while keeping the rest of the output dimensions clamped to 0. -ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices( - int64 output_rank, ArraySlice window_bounds, +ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices( + int64 output_rank, absl::Span slice_sizes, const GatherDimensionNumbers& dim_numbers) { std::vector index_base(output_rank, 0); std::vector index_count(output_rank, 1); - int64 window_bounds_idx = 0; + int64 slice_sizes_idx = 0; for (int64 i = 0; i < output_rank; i++) { bool is_output_window_dim = - c_binary_search(dim_numbers.output_window_dims(), i); + absl::c_binary_search(dim_numbers.offset_dims(), i); if (is_output_window_dim) { - while (c_binary_search(dim_numbers.elided_window_dims(), - window_bounds_idx)) { - window_bounds_idx++; + while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), + slice_sizes_idx)) { + slice_sizes_idx++; } - index_count[i] = window_bounds[window_bounds_idx++]; + index_count[i] = slice_sizes[slice_sizes_idx++]; } } @@ -599,30 +611,30 @@ 
ShapeUtil::IndexIterationSpace IterationSpaceForOutputWindowIndices( std::vector(output_rank, 1)}; } -// This functor computes the contribution of gather_indices to an input index +// This functor computes the contribution of start_indices to an input index // corresponding to an output index. That is, given an output index I, it picks -// out the gather output indices in I and uses them to look up a gather index, -// G, from the gather indices tensor, and expands G into the input space -// according to gather_dims_to_operand_dims. -class OutputGatherIndexToInputIndex { +// out the batch indices in I and uses them to look up a starting index, G, from +// the start indices tensor, and expands G into the input space according to +// start_index_map. +class OutputBatchIndexToInputIndex { public: // The constructor does some setup work that is amortized across all // iterations. - explicit OutputGatherIndexToInputIndex( + explicit OutputBatchIndexToInputIndex( const GatherDimensionNumbers* dim_numbers, const Shape& input_shape, - const Shape& output_shape, const Literal* gather_indices) - : dim_numbers_(*dim_numbers), gather_indices_(*gather_indices) { + const Shape& output_shape, const Literal* start_indices) + : dim_numbers_(*dim_numbers), start_indices_(*start_indices) { for (int64 i = 0; i < output_shape.dimensions_size(); i++) { - output_dim_is_gather_dims_.push_back( - !c_binary_search(dim_numbers_.output_window_dims(), i)); + output_dim_is_batch_dims_.push_back( + !absl::c_binary_search(dim_numbers_.offset_dims(), i)); } for (int64 i = 0; i < input_shape.dimensions_size(); i++) { int64 index_of_input_dim_in_index_vector = - std::distance(dim_numbers_.gather_dims_to_operand_dims().begin(), - c_find(dim_numbers_.gather_dims_to_operand_dims(), i)); + std::distance(dim_numbers_.start_index_map().begin(), + absl::c_find(dim_numbers_.start_index_map(), i)); if (index_of_input_dim_in_index_vector == - dim_numbers_.gather_dims_to_operand_dims_size()) { + 
dim_numbers_.start_index_map_size()) { input_dim_value_to_index_vector_.push_back(-1); } else { input_dim_value_to_index_vector_.push_back( @@ -630,14 +642,14 @@ class OutputGatherIndexToInputIndex { } } - index_vector_index_.resize(gather_indices_.shape().dimensions_size()); + index_vector_index_.resize(start_indices_.shape().dimensions_size()); input_index_.resize(input_shape.dimensions_size()); int64 index_vector_size = - gather_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); + start_indices_.shape().dimensions(dim_numbers_.index_vector_dim()); index_vector_.resize(index_vector_size); } - // Returns the contribution of gather_indices to the input index corresponding + // Returns the contribution of start_indices to the input index corresponding // to output_index. See gather_inner_loop_body. // // This is conceptually a stateless transformation from output_index to the @@ -650,24 +662,25 @@ class OutputGatherIndexToInputIndex { // index_vector_index_ and index_vector on every invocation, we reuse the // same storage for all invocations. // - // This returns an arrayslice into memory owned by the class. - StatusOr> operator()(ArraySlice output_index) { + // This returns a Span into memory owned by the class. + StatusOr> operator()( + absl::Span output_index) { PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index); TF_RETURN_IF_ERROR(FetchIndexVector()); PropagateIndexVectorToInputIndex(); - return ArraySlice(input_index_); + return absl::Span(input_index_); } private: - // Propagates the gather index dimensions from the output index into + // Propagates the batch dimensions from the output index into // index_vector_index_ by mutating index_vector_index_ in place. Does not // update the dim_numbers.index_vector_dim() dimension -- that's the dimension // we iterate over in FetchIndexVector. 
void PropagateOutputIndexGatherDimsToIndexVectorIndex( - ArraySlice output_index) { + absl::Span output_index) { int64 index_vector_index_i = 0; for (int64 i = 0, e = output_index.size(); i < e; i++) { - if (!output_dim_is_gather_dims_[i]) { + if (!output_dim_is_batch_dims_[i]) { continue; } @@ -679,14 +692,14 @@ class OutputGatherIndexToInputIndex { } } - // Populates index_vector_ by iterating over gather_indices_ according to + // Populates index_vector_ by iterating over start_indices_ according to // index_vector_index_. Status FetchIndexVector() { int64 index_vector_dim = dim_numbers_.index_vector_dim(); for (int64 i = 0, e = index_vector_.size(); i < e; i++) { index_vector_index_[index_vector_dim] = i; - TF_ASSIGN_OR_RETURN(index_vector_[i], gather_indices_.GetIntegralAsS64( - index_vector_index_)); + TF_ASSIGN_OR_RETURN(index_vector_[i], + start_indices_.GetIntegralAsS64(index_vector_index_)); } return Status::OK(); } @@ -708,40 +721,39 @@ class OutputGatherIndexToInputIndex { // PropagateIndexVectorToInputIndex. std::vector input_dim_value_to_index_vector_; - // output_dim_is_gather_dims_[i] is true iff the output index i is a gather + // output_dim_is_batch_dims_[i] is true iff the output index i is a gather // dimension. - std::vector output_dim_is_gather_dims_; + std::vector output_dim_is_batch_dims_; - // The buffer into which we construct an index into gather_indices_ to fetch + // The buffer into which we construct an index into start_indices_ to fetch // the index vector. std::vector index_vector_index_; - // The index vector fetched from gather_indices_. + // The index vector fetched from start_indices_. std::vector index_vector_; - // The result computed by this functor. operator() returns an ArraySlice into + // The result computed by this functor. operator() returns a Span into // this vector. 
std::vector input_index_; const GatherDimensionNumbers& dim_numbers_; - const Literal& gather_indices_; + const Literal& start_indices_; }; -// This functor computes the contribution of the window indices in an output +// This functor computes the contribution of the offset indices in an output // index to an input index. That is, given an output index I it picks out the -// output window indices in I and expands it into a window index into the input -// shape. -class OutputWindowIndexToInputIndex { +// output offset indices in I and expands it into an index into the input shape. +class OutputOffsetIndexToInputIndex { public: // The constructor does some setup work that is amortized across all // iterations. - explicit OutputWindowIndexToInputIndex( + explicit OutputOffsetIndexToInputIndex( const GatherDimensionNumbers& dim_numbers, const Shape& input_shape, const Shape& output_shape) { std::vector window_index_to_output_index; int64 output_index_count = 0; for (int64 i = 0; i < output_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.output_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.offset_dims(), i)) { window_index_to_output_index.push_back(output_index_count++); } else { output_index_count++; @@ -750,7 +762,7 @@ class OutputWindowIndexToInputIndex { int64 window_dim_count = 0; for (int64 i = 0; i < input_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { input_dim_value_to_output_index_.push_back(-1); } else { input_dim_value_to_output_index_.push_back( @@ -769,10 +781,11 @@ class OutputWindowIndexToInputIndex { // gather input index on every invocation we reuse the same storage for the // result (input_index_), mutating it in place. // - // This returns an arrayslice into memory owned by the class. - StatusOr> operator()(ArraySlice output_index) { + // This returns a Span into memory owned by the class. 
+ StatusOr> operator()( + absl::Span output_index) { PropagateOutputIndexWindowDimsToInputIndex(output_index); - return ArraySlice(input_index_); + return absl::Span(input_index_); } // Returns for a given 'input_dim' the corresponding output dimension index, @@ -785,7 +798,7 @@ class OutputWindowIndexToInputIndex { // Propagates window dimensions from the output index to input_index_ by // mutating input_index_ in place. void PropagateOutputIndexWindowDimsToInputIndex( - ArraySlice output_index) { + absl::Span output_index) { for (int64 i = 0, e = input_index_.size(); i < e; i++) { if (input_dim_value_to_output_index_[i] != -1) { input_index_[i] = output_index[input_dim_value_to_output_index_[i]]; @@ -801,27 +814,27 @@ class OutputWindowIndexToInputIndex { // PropagateOutputIndexWindowDimsToInputIndex. std::vector input_dim_value_to_output_index_; - // The result computed by this functor. operator() returns an ArraySlice into + // The result computed by this functor. operator() returns a Span into // this vector. std::vector input_index_; }; // Rehapes the gather indices input to have a trailing degenerate `1` dimension // if necessary. Hands over the ownership of the newly created literal (if -// there is one) to `reshaped_gather_indices`. +// there is one) to `reshaped_start_indices`. 
static StatusOr> ReshapedGatherIndices( - int64 index_vector_dim, const Literal& gather_indices, - std::unique_ptr* reshaped_gather_indices) { - if (gather_indices.shape().dimensions_size() != index_vector_dim) { - return std::cref(gather_indices); + int64 index_vector_dim, const Literal& start_indices, + std::unique_ptr* reshaped_start_indices) { + if (start_indices.shape().dimensions_size() != index_vector_dim) { + return std::cref(start_indices); } - std::vector new_shape(gather_indices.shape().dimensions().begin(), - gather_indices.shape().dimensions().end()); + std::vector new_shape(start_indices.shape().dimensions().begin(), + start_indices.shape().dimensions().end()); new_shape.push_back(1); - TF_ASSIGN_OR_RETURN(*reshaped_gather_indices, - gather_indices.Reshape(new_shape)); - return std::cref(**reshaped_gather_indices); + TF_ASSIGN_OR_RETURN(*reshaped_start_indices, + start_indices.Reshape(new_shape)); + return std::cref(**reshaped_start_indices); } Status HloEvaluator::HandleGather(HloInstruction* gather) { @@ -830,68 +843,67 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { const GatherDimensionNumbers& dim_numbers = gather->gather_dimension_numbers(); const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0)); - std::unique_ptr reshaped_gather_indices; + std::unique_ptr reshaped_start_indices; TF_ASSIGN_OR_RETURN( - const Literal& gather_indices, + const Literal& start_indices, ReshapedGatherIndices(dim_numbers.index_vector_dim(), GetEvaluatedLiteralFor(gather->operand(1)), - &reshaped_gather_indices)); + &reshaped_start_indices)); // We iterate over the gather dimensions in the output shape in an outer loop // nest, and iterate over the window dimensions in the output shape in an // inner loop nest. 
- ShapeUtil::IndexIterationSpace gather_indices_iteration_space = - IterationSpaceForOutputGatherIndices(shape, dim_numbers); - ShapeUtil::IndexIterationSpace window_indices_iteration_space = - IterationSpaceForOutputWindowIndices( - shape.dimensions_size(), gather->gather_window_bounds(), dim_numbers); + ShapeUtil::IndexIterationSpace start_indices_iteration_space = + IterationSpaceForOutputBatchIndices(shape, dim_numbers); + ShapeUtil::IndexIterationSpace offset_indices_iteration_space = + IterationSpaceForOutputOffsetIndices( + shape.dimensions_size(), gather->gather_slice_sizes(), dim_numbers); // Scratch buffers that hold an index in the output shape and the // corresponding index in the input shape. std::vector input_index(operand.shape().dimensions_size()); std::vector output_index(gather->shape().dimensions_size()); - std::vector input_gather_index_clamped( - operand.shape().dimensions_size()); + std::vector input_index_clamped(operand.shape().dimensions_size()); - OutputGatherIndexToInputIndex output_gather_index_to_input_index( + OutputBatchIndexToInputIndex output_batch_index_to_input_index( &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), - /*output_shape=*/shape, &gather_indices); - OutputWindowIndexToInputIndex output_window_index_to_input_index( + /*output_shape=*/shape, &start_indices); + OutputOffsetIndexToInputIndex output_offset_index_to_input_index( gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(), /*output_shape=*/shape); const Shape& operand_shape = operand.shape(); auto gather_inner_loop_body = - [&](ArraySlice output_window_index, - ArraySlice input_gather_index, - ArraySlice output_gather_index) -> StatusOr { + [&](absl::Span output_window_index, + absl::Span input_gather_index, + absl::Span output_gather_index) -> StatusOr { TF_ASSIGN_OR_RETURN( - ArraySlice input_window_index, - output_window_index_to_input_index(output_window_index)); + absl::Span input_window_index, + 
output_offset_index_to_input_index(output_window_index)); for (int i = 0, e = output_index.size(); i < e; i++) { output_index[i] = output_gather_index[i] + output_window_index[i]; DCHECK_LT(output_index[i], shape.dimensions(i)); } for (int i = 0, e = input_gather_index.size(); i < e; i++) { int64 output_dim = - output_window_index_to_input_index.input_dim_value_to_output_index(i); + output_offset_index_to_input_index.input_dim_value_to_output_index(i); // If 'output_dim' is -1, it means 'i' is an elided window dim. This means // we set the iteration index to 0, so for the purpose of the following // calculations we can consider the output dimension size to be 1. int64 output_dim_size = output_dim == -1 ? 1 : shape.dimensions(output_dim); // Clamp the gather index so that the gather region fits in the operand. - // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0, + // input_index_clamped[i] = clamp(input_gather_index[i], 0, // operand_shape.dimensions(i) - // output_dim_size); - input_gather_index_clamped[i] = + input_index_clamped[i] = std::min(operand_shape.dimensions(i) - output_dim_size, std::max(0LL, input_gather_index[i])); } for (int i = 0, e = input_index.size(); i < e; i++) { - input_index[i] = input_gather_index_clamped[i] + input_window_index[i]; + input_index[i] = input_index_clamped[i] + input_window_index[i]; DCHECK_GE(input_index[i], 0); DCHECK_LT(input_index[i], operand_shape.dimensions(i)); } @@ -901,19 +913,18 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { }; auto gather_outer_loop_body = - [&](ArraySlice output_gather_index) -> StatusOr { - TF_ASSIGN_OR_RETURN( - ArraySlice input_gather_index, - output_gather_index_to_input_index(output_gather_index)); + [&](absl::Span output_gather_index) -> StatusOr { + TF_ASSIGN_OR_RETURN(absl::Span input_gather_index, + output_batch_index_to_input_index(output_gather_index)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - shape, window_indices_iteration_space, + shape, 
offset_indices_iteration_space, std::bind(gather_inner_loop_body, std::placeholders::_1, input_gather_index, output_gather_index))); return true; }; TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( - shape, gather_indices_iteration_space, gather_outer_loop_body)); + shape, start_indices_iteration_space, gather_outer_loop_body)); evaluated_[gather] = std::move(result); return Status::OK(); } @@ -960,7 +971,7 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); - evaluated_[get_tuple_element] = MakeUnique( + evaluated_[get_tuple_element] = absl::make_unique( ShapeUtil::GetTupleElementShape(operand->shape(), index)); return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, /*dest_shape_index=*/{}, @@ -1098,8 +1109,8 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { HloEvaluator loop_body_evaluator(max_loop_iterations_); while (keep_going) { if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { - return InvalidArgument("Loop %s exceeded loop iteration limit (%lld).", - while_hlo->name().c_str(), max_loop_iterations_); + return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", + while_hlo->name(), max_loop_iterations_); } TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate( *cond_comp, {lcv.get()})); @@ -1162,12 +1173,12 @@ StatusOr> EvaluateSortInternal( result_keys.push_back(key_value.first); result_values.push_back(key_value.second); } - auto result_keys_literal = MakeUnique(keys_literal.shape()); - result_keys_literal->PopulateR1( - tensorflow::gtl::ArraySlice(result_keys)); - auto result_values_literal = MakeUnique(values_literal.shape()); + auto result_keys_literal = absl::make_unique(keys_literal.shape()); + result_keys_literal->PopulateR1(absl::Span(result_keys)); + auto result_values_literal = + absl::make_unique(values_literal.shape()); result_values_literal->PopulateR1( - 
tensorflow::gtl::ArraySlice(result_values)); + absl::Span(result_values)); return std::make_pair(std::move(result_keys_literal), std::move(result_values_literal)); }; @@ -1180,8 +1191,9 @@ StatusOr> EvaluateSortInternal( } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. - auto keys_result_literal = MakeUnique(keys_literal.shape()); - auto values_result_literal = MakeUnique(values_literal.shape()); + auto keys_result_literal = absl::make_unique(keys_literal.shape()); + auto values_result_literal = + absl::make_unique(values_literal.shape()); int64 r1_length = keys_literal.shape().dimensions(1); for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto keys_r1_slice, @@ -1253,7 +1265,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape()); if (sort_dim != rank - 1) { return Unimplemented( - "Trying to support along dimension %lld, which is not the last " + "Trying to sort along dimension %d, which is not the last " "dimension", sort_dim); } @@ -1272,9 +1284,25 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { } } +Status HloEvaluator::HandleReduce(HloInstruction* reduce) { + if (!ShapeUtil::IsTuple(reduce->shape())) { + return DefaultAction(reduce); + } else { + auto first_element_type = reduce->shape().tuple_shapes(0).element_type(); + for (const auto& tuple_shape : reduce->shape().tuple_shapes()) { + if (tuple_shape.element_type() != first_element_type) { + return Unimplemented( + "Reduce with several outputs that have mixed element types is " + "unsupported"); + } + } + return reduce->Visit(typed_visitors_.at(first_element_type).get()); + } +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); - return Status::OK(); + return ShapeUtil::ValidateShape(hlo->shape()); } Status HloEvaluator::Postprocess(HloInstruction* hlo) { @@ -1286,26 +1314,27 @@ Status 
HloEvaluator::Postprocess(HloInstruction* hlo) { // Explicit instantiation of templatized Evaluate* methods. // template StatusOr> -HloEvaluator::Evaluate(const HloModule& module, - ArraySlice arg_literals); +HloEvaluator::Evaluate( + const HloModule& module, absl::Span arg_literals); template StatusOr> HloEvaluator::Evaluate>( - const HloModule& module, ArraySlice> arg_literals); + const HloModule& module, + absl::Span> arg_literals); -template StatusOr> -HloEvaluator::Evaluate(const HloComputation& computation, - ArraySlice arg_literals); +template StatusOr> HloEvaluator::Evaluate< + const Literal*>(const HloComputation& computation, + absl::Span arg_literals); template StatusOr> HloEvaluator::Evaluate>( const HloComputation& computation, - ArraySlice> arg_literals); + absl::Span> arg_literals); template StatusOr> -HloEvaluator::Evaluate(HloInstruction* instruction, - ArraySlice arg_literals); +HloEvaluator::Evaluate( + HloInstruction* instruction, absl::Span arg_literals); template StatusOr> HloEvaluator::Evaluate>( HloInstruction* instruction, - ArraySlice> arg_literals); + absl::Span> arg_literals); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index a4c37ef32827892194da070ee05ec6dc4f4c306f..c2d49e56ac487ee8a5cb3d26aee497ade63aa844 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,7 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -27,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" @@ -51,8 +51,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // type. template StatusOr> Evaluate( - const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals); + const HloModule& module, absl::Span arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. @@ -75,7 +74,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { template StatusOr> Evaluate( const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals); + absl::Span arg_literals); // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. @@ -87,8 +86,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // type. template StatusOr> Evaluate( - HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals); + HloInstruction* instruction, absl::Span arg_literals); // Evaluates a single HLO instruction with constant operands. // Returns the evaluated result as literal if successful. @@ -185,6 +183,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSort(HloInstruction* sort) override; + Status HandleReduce(HloInstruction* reduce) override; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. 
@@ -222,13 +222,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(operand->shape()).c_str()); + ShapeUtil::HumanString(shape), + ShapeUtil::HumanString(operand->shape())); } - auto result = MakeUnique(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + auto result = absl::make_unique(shape); + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 3ac6d68df30955d2e5e06e1e76d2182772151b47..7e490d7f324022fdf02c569fc1986d0b6f5823ba 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -51,12 +52,15 @@ static std::array use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface, public HloVerifiedTestBase { protected: - HloEvaluatorTest() : use_bfloat16_(GetParam()) { - evaluator_ = MakeUnique(); + HloEvaluatorTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false), + use_bfloat16_(GetParam()) { + evaluator_ = absl::make_unique(); } std::unique_ptr Evaluate( - tensorflow::gtl::ArraySlice arg_literals = {}) { + absl::Span arg_literals = {}) { if (use_bfloat16_) { // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. 
auto type_converter = HloElementTypeConverter(F32, BF16); @@ -340,7 +344,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; result->EachCell( - [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + [&](absl::Span indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); EXPECT_NEAR(value, literal_clone->Get(rindexes), 0.031250); }); @@ -523,7 +527,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { std::unique_ptr result = Evaluate(); - auto expected_array = MakeUnique>(8, 5, 1, 1); + auto expected_array = absl::make_unique>(8, 5, 1, 1); expected_array->Fill(kPadValue); (*expected_array)(1, 0, 0, 0) = 1.0f; (*expected_array)(1, 2, 0, 0) = 2.0f; @@ -547,7 +551,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto input_array = MakeUnique>(4, 3); + auto input_array = absl::make_unique>(4, 3); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = @@ -568,7 +572,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { std::unique_ptr result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } - auto expected_array = MakeUnique>(1, 5); + auto expected_array = absl::make_unique>(1, 5); (*expected_array)(0, 0) = 7.0f; (*expected_array)(0, 1) = 2.718f; (*expected_array)(0, 2) = 2.718f; @@ -588,7 +592,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto input_array = MakeUnique>(4, 3); + auto input_array = absl::make_unique>(4, 3); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D(*input_array); HloInstruction* input_instruction = @@ -612,7 +616,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { std::unique_ptr result = Evaluate(); - auto expected_array = MakeUnique>(0, 9); + auto expected_array = absl::make_unique>(0, 9); auto expected = 
LiteralUtil::CreateR2FromArray2D(*expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -628,7 +632,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // { 3 }, // { 4 }, // } - auto lhs_array = MakeUnique>(4, 1); + auto lhs_array = absl::make_unique>(4, 1); lhs_array->FillUnique(1.0f); auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = @@ -679,7 +683,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { // { 3, 4 }, // { 5, 6 }, // } - auto rhs_array = MakeUnique>(3, 2); + auto rhs_array = absl::make_unique>(3, 2); rhs_array->FillUnique(1.0f); auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = @@ -710,7 +714,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto lhs_array = MakeUnique>(4, 3); + auto lhs_array = absl::make_unique>(4, 3); lhs_array->FillUnique(1.0f); auto lhs_literal = LiteralUtil::CreateR2FromArray2D(*lhs_array); HloInstruction* lhs_instruction = @@ -722,7 +726,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // { 3, 4 }, // { 5, 6 }, // } - auto rhs_array = MakeUnique>(3, 2); + auto rhs_array = absl::make_unique>(3, 2); rhs_array->FillUnique(1.0f); auto rhs_literal = LiteralUtil::CreateR2FromArray2D(*rhs_array); HloInstruction* rhs_instruction = @@ -931,7 +935,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] Array4D expected_array({{{{2514, 2685}}}}); - Array4D expected_array_bf16({{{{2512, 2672}}}}); + Array4D expected_array_bf16({{{{2512, 2688}}}}); // clang-format on auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? 
expected_array_bf16 : expected_array); @@ -1008,7 +1012,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] Array4D expected_array({{{{2514, 2685}}}}); - Array4D expected_array_bf16({{{{2512, 2672}}}}); + Array4D expected_array_bf16({{{{2512, 2688}}}}); // clang-format on auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); @@ -1297,7 +1301,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique>(2, 3); + auto arg_array = absl::make_unique>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); @@ -1339,7 +1343,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique>(2, 3); + auto arg_array = absl::make_unique>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); @@ -1390,7 +1394,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique>(2, 3); + auto arg_array = absl::make_unique>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D(*arg_array); @@ -1511,7 +1515,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { // { 9, 10, 11, 12, 13 }, // { 17, 18, 19, 20, 21 }, // } - auto operand_array = MakeUnique>(3, 5); + auto operand_array = absl::make_unique>(3, 5); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1544,7 +1548,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { // { 1, 2, 3, 4 }, // { 5, 6, 7, 8 }, // } - auto operand_array = MakeUnique>(2, 4); + auto operand_array = absl::make_unique>(2, 4); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1580,7 +1584,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { // { 1, 
2, 3, 4 }, // { 5, 6, 7, 8 }, // } - auto operand_array = MakeUnique>(2, 4); + auto operand_array = absl::make_unique>(2, 4); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1614,7 +1618,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique>(2, 3); + auto operand_array = absl::make_unique>(2, 3); operand_array->FillUnique(1.0); auto operand_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1651,7 +1655,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique>(2, 3); + auto operand_array = absl::make_unique>(2, 3); operand_array->FillUnique(1.0); auto operand_literal2 = LiteralUtil::CreateR2FromArray2D(*operand_array); @@ -1687,7 +1691,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique>(2, 3); + auto operand_array = absl::make_unique>(2, 3); operand_array->FillUnique(1.0); HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( @@ -1826,21 +1830,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 3} + slice_sizes={1, 3} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( *LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, 
EvaluateGather_TensorFlowGatherV2) { @@ -1851,21 +1854,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( *LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1876,22 +1878,22 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal( *LiteralUtil::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1902,11 +1904,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - 
output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; ParseAndVerifyModule(hlo_text); @@ -1914,11 +1916,11 @@ ENTRY main { LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, @@ -1930,11 +1932,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; ParseAndVerifyModule(hlo_text); @@ -1942,11 +1944,11 @@ ENTRY main { LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -1957,21 +1959,20 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, 
index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({1, 1}); + std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{5}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -1982,21 +1983,21 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{8}}, {{5}}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2007,20 +2008,19 @@ ENTRY main { operand = s32[3,0] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,0] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 0} + slice_sizes={1, 0} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); + std::unique_ptr start_indices = 
LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{}, {}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { @@ -2031,21 +2031,21 @@ ENTRY main { operand = s32[3] parameter(0) indices = s32[2,2,1] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1} + slice_sizes={1} } )"; ParseAndVerifyModule(hlo_text); std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{0, 1}, {2, 1}}), - *Evaluate({operand.get(), gather_indices.get()}))); + *Evaluate({operand.get(), start_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { @@ -2517,6 +2517,31 @@ TEST_P(HloEvaluatorTest, DoesCompareBF16) { std::move(rhs)); } +TEST_P(HloEvaluatorTest, Bf16Reduction) { + const string hlo_text = R"( +HloModule Bf16Reduction + +add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs) +} + +ENTRY main { + arg0 = bf16[4]{0} parameter(0) + init = bf16[] constant(0) + ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16 +} +)"; + ParseAndVerifyModule(hlo_text); + + std::unique_ptr arg = LiteralUtil::CreateR1( + {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); + std::unique_ptr expected = + LiteralUtil::CreateR0(bfloat16(44.0f)); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()}))); +} + INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, 
::testing::ValuesIn(use_bf16_params)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 084b49b4783fe15e91917317d8b3746e2c7569d0..cb27e13e99c0192a9796d3d32eba2637e7db06bc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,11 +16,16 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/core/lib/core/casts.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -86,6 +91,29 @@ bool SafeLess(const NativeT& a, const NativeT& b) { // of this class. template class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { + private: + // Get the value in the given literal static_cast as a double. + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + double GetAsDouble(const Literal& literal, + absl::Span input_index) { + return static_cast(literal.Get(input_index)); + } + + // Specialization for complex types. In this case it is not possible to + // static_cast value to a double so just CHECK fail. This method is not used + // at run-time, but must be available at compile-time to keep the compiler + // happy. 
+ template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + double GetAsDouble(const Literal& literal, + absl::Span input_index) { + LOG(FATAL) << "Trying to get complex literal as double: " + << literal.ToString(); + } + public: explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {} @@ -117,7 +145,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", - HloOpcodeString(hlo_instruction->opcode()).c_str()); + HloOpcodeString(hlo_instruction->opcode())); } // TODO(b/35950897): many of the stl functions used in the handlers are not @@ -525,7 +553,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleDivide(HloInstruction* divide) override { + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleDivide(HloInstruction* divide) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { @@ -534,6 +566,46 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value && + std::is_integral::value>::type* = + nullptr> + Status HandleDivide(HloInstruction* divide) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[divide], + ElementWiseBinaryOp( + divide, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) -> ElementwiseT { + if (rhs_elem == 0) { + return static_cast(-1); + } + if (rhs_elem == -1 && + lhs_elem == std::numeric_limits::min()) { + return lhs_elem; + } + return lhs_elem / rhs_elem; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> + Status HandleDivide(HloInstruction* divide) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return 
rhs_elem == 0 + ? std::numeric_limits::max() + : (lhs_elem / rhs_elem); + })); + return Status::OK(); + } + + Status HandleDivide(HloInstruction* divide) { + return HandleDivide(divide); + } + template ::value>::type* = nullptr> @@ -620,9 +692,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, @@ -632,6 +703,40 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value>::type* = + nullptr> + Status HandleRemainder(HloInstruction* remainder) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return rhs_el == 0 ? lhs_el : (lhs_el % rhs_el); + })); + return Status::OK(); + } + + template ::value && + std::is_integral::value>::type* = + nullptr> + Status HandleRemainder(HloInstruction* remainder) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[remainder], + ElementWiseBinaryOp( + remainder, + [](ElementwiseT lhs_el, ElementwiseT rhs_el) -> ElementwiseT { + if (rhs_el == 0) { + return lhs_el; + } + if (rhs_el == -1 && + lhs_el == std::numeric_limits::min()) { + return 0; + } + return lhs_el % rhs_el; + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -873,10 +978,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = MakeUnique(result_shape); + auto result = absl::make_unique(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice out_index) { + 
TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span out_index) { std::vector from_index(out_index.begin(), out_index.end()); for (const int64 dim : reverse_dimensions) { from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; @@ -943,8 +1048,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data]( - tensorflow::gtl::ArraySlice out_index) { + rhs_literal_data](absl::Span out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1025,12 +1129,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { static_cast(rhs_literal_data[rhs_linear_index]); } cnt : {} - } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); + } while (IndexUtil::BumpIndices(window_shape, + absl::MakeSpan(rhs_spatial_index))); return static_cast(result_val); }; - auto result = MakeUnique(result_shape); + auto result = absl::make_unique(result_shape); TF_RETURN_IF_ERROR(result->PopulateParallel(func)); parent_->evaluated_[conv] = std::move(result); @@ -1078,7 +1183,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // result_index_locations[i] contains one or two pointers to the locations // in lhs_index or rhs_index where the i'th result index should go. 
- tensorflow::gtl::InlinedVector, kInlineRank> + absl::InlinedVector, kInlineRank> result_index_locations; result_index_locations.reserve(lhs_rank + rhs_rank - 2); @@ -1093,20 +1198,20 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Then we have the LHS and RHS non-contracting dimensions, if any: for (int64 i = 0; i < lhs_rank; i++) { if (i != lhs_contracting_dimension && - !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) { + !absl::c_linear_search(dnums.lhs_batch_dimensions(), i)) { result_index_locations.push_back({&lhs_index[i], nullptr}); } } for (int64 i = 0; i < rhs_rank; i++) { if (i != rhs_contracting_dimension && - !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) { + !absl::c_linear_search(dnums.rhs_batch_dimensions(), i)) { result_index_locations.push_back({&rhs_index[i], nullptr}); } } - auto result = MakeUnique(dot->shape()); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice result_index) { + auto result = absl::make_unique(dot->shape()); + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span result_index) { ElementwiseT result_val = static_cast(0); for (int64 i = 0; i < result_index.size(); i++) { @@ -1153,11 +1258,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Create new HLO of padded shape with padding value. ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = MakeUnique(pad->shape()); + auto result = absl::make_unique(pad->shape()); TF_RETURN_IF_ERROR(result->Populate( - [&scalar](tensorflow::gtl::ArraySlice multi_index) { - return scalar; - })); + [&scalar](absl::Span multi_index) { return scalar; })); const Literal& evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); @@ -1170,7 +1273,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // corresponding index of the resulting padded literal. 
const PaddingConfig& pad_config = pad->padding_config(); - auto func = [&](tensorflow::gtl::ArraySlice input_index) { + auto func = [&](absl::Span input_index) { for (auto i = 0; i < input_index.size(); ++i) { // Interior padding occurs logically before edge padding, so in the case // of negative edge padding elements are removed from the @@ -1318,11 +1421,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = MakeUnique(map->shape()); + auto result = absl::make_unique(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { std::vector> arg_literals; arg_literals.reserve(operands.size()); @@ -1432,9 +1535,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](const ReturnT& a, const ReturnT& b) { return SafeLess(a, b); }); - auto result_literal = MakeUnique(keys_literal.shape()); - result_literal->PopulateR1( - tensorflow::gtl::ArraySlice(result_data)); + auto result_literal = absl::make_unique(keys_literal.shape()); + result_literal->PopulateR1(absl::Span(result_data)); VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); return result_literal; }; @@ -1444,7 +1546,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. 
- auto result_literal = MakeUnique(keys_literal.shape()); + auto result_literal = absl::make_unique(keys_literal.shape()); int64 r1_length = keys->shape().dimensions(1); for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto r1_slice, @@ -1472,20 +1574,20 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleSort(sort); } - Status HandleReduce(HloInstruction* reduce) override { - // TODO(b/112040122): Support variadic reduce. - if (!ShapeUtil::IsArray(reduce->shape())) { - return Unimplemented("Variadic reduce is not supported in the Evaluator"); - } - auto arg = reduce->operand(0); - auto init_value = reduce->operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + Status HandleReduce(HloInstruction* hlo) override { + HloReduceInstruction* reduce = Cast(hlo); + int64 num_args = reduce->inputs().size(); + bool has_tuple_output = ShapeUtil::IsTuple(reduce->shape()); + absl::Span dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); - TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == - ShapeUtil::Rank(arg->shape()) - dimensions.size()); + + absl::InlinedVector operand_shapes; + for (const HloInstruction* operand : reduce->operands()) { + operand_shapes.push_back(&operand->shape()); + } TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferReduceShape( - {&arg->shape(), &init_value->shape()}, + operand_shapes, /*dimensions_to_reduce=*/dimensions, /*to_apply=*/function->ComputeProgramShape())); TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) @@ -1493,14 +1595,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); - const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); - VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); - const Literal& init_literal = 
parent_->GetEvaluatedLiteralFor(init_value); - VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); + absl::InlinedVector arg_literals(num_args); + absl::InlinedVector init_literals(num_args); + for (int64 i = 0; i < num_args; ++i) { + arg_literals[i] = &parent_->GetEvaluatedLiteralFor(reduce->inputs()[i]); + VLOG(3) << "HandleReduce arg_literal: " << arg_literals[i]->ToString(); + init_literals[i] = + &parent_->GetEvaluatedLiteralFor(reduce->init_values()[i]); + VLOG(3) << "HandleReduce init_literal: " << init_literals[i]->ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literals[i]->shape())); + } - const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); + // All args and results have the same dimensions, so pick an arbitrary one. + const Shape& arg_shape = arg_literals[0]->shape(); + const Shape& result_shape = ShapeUtil::IsTuple(reduce->shape()) + ? reduce->shape().tuple_shapes(0) + : reduce->shape(); + const auto arg_dimensions = AsInt64Slice(arg_shape.dimensions()); std::vector arg_dim_steps(arg_dimensions.size()); std::vector arg_dim_counts(arg_dimensions.size()); for (const int64 dim : dimensions) { @@ -1518,61 +1629,110 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = MakeUnique(reduce->shape()); - // For each resulting dimension, calculate and assign computed value. 
- TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - ReturnT result_val = init_scalar; + absl::InlinedVector, 1> results(num_args); + for (int64 i = 0; i < num_args; ++i) { + results[i] = absl::make_unique(result_shape); + } - std::vector base(arg_dimensions.size()); - for (int64 i = 0; i < multi_index.size(); ++i) { - base[result_to_arg_index[i]] = multi_index[i]; - } + Status eval_status; + // For each resulting dimension, calculate and assign computed values. + // This is really wasteful when num_args > 1, since we re-run the + // reduction num_args time. The alternative is to teach Populate() about + // tuples, which we should probably do. + absl::InlinedVector init_scalars(num_args); + for (int i = 0; i < num_args; ++i) { + init_scalars[i] = init_literals[i]->Get({}); + } - // When the reduction is addition of floats, accumulate in a double - // for better precision. Also, avoid creating Literals for the - // intermediate results; it's much faster. - if (ShapeUtil::ElementIsFloating(init_literal.shape()) && - IsScalarAdd(function)) { - double computed_result = 0; - auto func = [&](tensorflow::gtl::ArraySlice input_index) { - computed_result += arg_literal.Get(input_index); + for (int64 input = 0; input < num_args; ++input) { + TF_RETURN_IF_ERROR(results[input]->Populate( + [&](absl::Span multi_index) { + if (!eval_status.ok()) { + return init_scalars[input]; + } + absl::InlinedVector result_values(init_scalars.begin(), + init_scalars.end()); + std::vector base(arg_dimensions.size()); + for (int64 i = 0; i < multi_index.size(); ++i) { + base[result_to_arg_index[i]] = multi_index[i]; + } + + // When the reduction is addition of floats, accumulate in a double + // for better precision. Also, avoid creating Literals for the + // intermediate results; it's much faster. 
+ if (ShapeUtil::ElementIsFloating(init_literals[0]->shape()) && + IsScalarAdd(function)) { + CHECK_EQ(num_args, 1); + double computed_result = 0; + auto func = [&](absl::Span input_index) { + computed_result += + GetAsDouble(*arg_literals[0], input_index); + return true; + }; + ShapeUtil::ForEachIndex(arg_literals[0]->shape(), base, + arg_dim_counts, arg_dim_steps, func); + return static_cast(computed_result); + } + auto func = + [&](absl::Span input_index) -> StatusOr { + absl::InlinedVector arg_values(num_args); + for (int64 i = 0; i < num_args; ++i) { + arg_values[i] = arg_literals[i]->Get(input_index); + } + + // Evaluate computation with specified literal operands. + absl::InlinedVector, 1> + embedded_operands; + for (ReturnT value : result_values) { + embedded_operands.push_back( + LiteralUtil::CreateR0(value)); + } + for (ReturnT value : arg_values) { + embedded_operands.push_back( + LiteralUtil::CreateR0(value)); + } + absl::InlinedVector embedded_operands_ptrs( + embedded_operands.size()); + std::transform(embedded_operands.begin(), embedded_operands.end(), + embedded_operands_ptrs.begin(), + [](const std::unique_ptr& ptr) { + return ptr.get(); + }); + + TF_ASSIGN_OR_RETURN(std::unique_ptr computed_result, + embedded_evaluator.Evaluate( + *function, embedded_operands_ptrs)); + // Clear visit states so that we can use the evaluator again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + // Assign computed result to result_val. 
+ if (!has_tuple_output) { + result_values[0] = computed_result->Get({}); + } else { + for (int64 i = 0; i < num_args; ++i) { + result_values[i] = computed_result->Get( + /*multi_index=*/{}, /*shape_index=*/{i}); + } + } return true; }; - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return static_cast(computed_result); - } - auto func = [&](tensorflow::gtl::ArraySlice input_index) { - auto curr_val = arg_literal.Get(input_index); - - // Evaluate computation with specified literal operands. - auto curr_val_literal = LiteralUtil::CreateR0(curr_val); - auto result_val_literal = - LiteralUtil::CreateR0(result_val); - - std::unique_ptr computed_result = - embedded_evaluator - .Evaluate( - *function, - {result_val_literal.get(), curr_val_literal.get()}) - .ConsumeValueOrDie(); - // Clear visit states so that we can use the evaluator again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - // Assign computed result to result_val. - result_val = computed_result->Get({}); - return true; - }; - // Computes one element of the result, reducing all dimensions that - // contribute to that element. - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return result_val; - })); - - parent_->evaluated_[reduce] = std::move(result); - return Status::OK(); + // Computes one element of the result, reducing all dimensions that + // contribute to that element. 
+ eval_status = ShapeUtil::ForEachIndexWithStatus( + arg_shape, base, arg_dim_counts, arg_dim_steps, func); + return result_values[input]; + })); + } + if (!has_tuple_output) { + parent_->evaluated_[reduce] = std::move(results[0]); + } else { + auto tuple_result = absl::make_unique(reduce->shape()); + for (int64 i = 0; i < num_args; ++i) { + TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i})); + } + parent_->evaluated_[reduce] = std::move(tuple_result); + } + return eval_status; } bool IsScalarAdd(HloComputation* computation) { @@ -1599,13 +1759,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = MakeUnique(select_and_scatter->shape()); + auto result = absl::make_unique(select_and_scatter->shape()); // Initialize result array with the init value. TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { - return init_scalar; - })); + [&](absl::Span output_index) { return init_scalar; })); std::vector window_dimension_sizes; for (const auto& window_dimension : window.dimensions()) { @@ -1643,8 +1801,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // 2. Using the selected index, scatter value from `source` to result. We // do this by iterating through the window, and compare each index with // the selected index. 
- tensorflow::gtl::optional selected_val; - tensorflow::gtl::optional> selected_index; + absl::optional selected_val; + absl::optional> selected_index; IterateThroughWindow( window_shape, window, operand_literal.shape(), source_index, @@ -1691,7 +1849,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { embedded_evaluator.ResetVisitStates(); } }); - } while (IndexUtil::BumpIndices(source->shape(), &source_index)); + } while ( + IndexUtil::BumpIndices(source->shape(), absl::MakeSpan(source_index))); parent_->evaluated_[select_and_scatter] = std::move(result); return Status::OK(); @@ -1735,10 +1894,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = MakeUnique(reduce_window->shape()); + auto result = absl::make_unique(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. 
- TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice output_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span output_index) { ReturnT result_val = init_scalar; std::fill(window_index.begin(), window_index.end(), 0); @@ -1802,7 +1961,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector index_count(updates_rank, 1); for (int64 i = 0; i < updates_rank; i++) { bool is_update_scatter_dim = - !c_binary_search(dim_numbers.update_window_dims(), i); + !absl::c_binary_search(dim_numbers.update_window_dims(), i); if (is_update_scatter_dim) { index_count[i] = updates_shape.dimensions(i); } @@ -1821,7 +1980,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector index_count(updates_rank, 1); for (int64 i = 0; i < updates_rank; i++) { bool is_update_window_dim = - c_binary_search(dim_numbers.update_window_dims(), i); + absl::c_binary_search(dim_numbers.update_window_dims(), i); if (is_update_window_dim) { index_count[i] = updates_shape.dimensions(i); } @@ -1848,7 +2007,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { : dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) { for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { update_dim_is_scatter_dims_.push_back( - !c_binary_search(dim_numbers_.update_window_dims(), i)); + !absl::c_binary_search(dim_numbers_.update_window_dims(), i)); } for (int64 i = 0; i < input_shape.dimensions_size(); i++) { @@ -1883,13 +2042,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // index_vector_index_ and index_vector on every invocation, we reuse the // same storage for all invocations. // - // This returns an arrayslice into memory owned by the class. - StatusOr> operator()( - tensorflow::gtl::ArraySlice update_index) { + // This returns a Span into memory owned by the class. 
+ StatusOr> operator()( + absl::Span update_index) { PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index); TF_RETURN_IF_ERROR(FetchIndexVector()); PropagateIndexVectorToInputIndex(); - return tensorflow::gtl::ArraySlice(input_index_); + return absl::Span(input_index_); } private: @@ -1898,7 +2057,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // update the dim_numbers.index_vector_dim() dimension -- that's the // dimension we iterate over in FetchIndexVector. void PropagateUpdateIndexScatterDimsToIndexVectorIndex( - tensorflow::gtl::ArraySlice update_index) { + absl::Span update_index) { int64 index_vector_index_i = 0; for (int64 i = 0, e = update_index.size(); i < e; i++) { if (!update_dim_is_scatter_dims_[i]) { @@ -1953,7 +2112,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // The index vector fetched from scatter_indices_. std::vector index_vector_; - // The result computed by this functor. operator() returns an ArraySlice + // The result computed by this functor. operator() returns a Span // into this vector. 
std::vector input_index_; @@ -1978,7 +2137,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector window_index_to_update_index; int64 update_index_count = 0; for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.update_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { window_index_to_update_index.push_back(update_index_count++); } else { update_index_count++; @@ -1987,7 +2146,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { int64 window_dim_count = 0; for (int64 i = 0; i < input_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.inserted_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { input_dim_value_to_update_index_.push_back(-1); } else { input_dim_value_to_update_index_.push_back( @@ -2006,11 +2165,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // scatter input index on every invocation we reuse the same storage for the // result (input_index_), mutating it in place. // - // This returns an arrayslice into memory owned by the class. - StatusOr> operator()( - tensorflow::gtl::ArraySlice update_index) { + // This returns a Span into memory owned by the class. + StatusOr> operator()( + absl::Span update_index) { PropagateUpdateIndexWindowDimsToInputIndex(update_index); - return tensorflow::gtl::ArraySlice(input_index_); + return absl::Span(input_index_); } // Returns for a given 'input_dim' the corresponding update dimension index, @@ -2023,7 +2182,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Propagates window dimensions from the update index to input_index_ by // mutating input_index_ in place. 
void PropagateUpdateIndexWindowDimsToInputIndex( - tensorflow::gtl::ArraySlice update_index) { + absl::Span update_index) { for (int64 i = 0, e = input_index_.size(); i < e; i++) { if (input_dim_value_to_update_index_[i] != -1) { input_index_[i] = update_index[input_dim_value_to_update_index_[i]]; @@ -2039,7 +2198,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // PropagateUpdateIndexWindowDimsToInputIndex. std::vector input_dim_value_to_update_index_; - // The result computed by this functor. operator() returns an ArraySlice + // The result computed by this functor. operator() returns a Span // into this vector. std::vector input_index_; }; @@ -2082,12 +2241,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::unique_ptr result = operand.CloneToUnique(); HloEvaluator embedded_evaluator; auto scatter_inner_loop_body = - [&](tensorflow::gtl::ArraySlice update_window_index, - tensorflow::gtl::ArraySlice input_scatter_index, - tensorflow::gtl::ArraySlice update_scatter_index) - -> StatusOr { + [&](absl::Span update_window_index, + absl::Span input_scatter_index, + absl::Span update_scatter_index) -> StatusOr { TF_ASSIGN_OR_RETURN( - tensorflow::gtl::ArraySlice input_window_index, + absl::Span input_window_index, update_window_index_to_input_index(update_window_index)); for (int i = 0, e = update_index.size(); i < e; i++) { update_index[i] = update_scatter_index[i] + update_window_index[i]; @@ -2136,14 +2294,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { }; auto scatter_outer_loop_body = - [&](tensorflow::gtl::ArraySlice update_scatter_index) - -> StatusOr { + [&](absl::Span update_scatter_index) -> StatusOr { TF_ASSIGN_OR_RETURN( - tensorflow::gtl::ArraySlice input_scatter_index, + absl::Span input_scatter_index, update_scatter_index_to_input_index(update_scatter_index)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( updates_shape, window_indices_iteration_space, - 
[&](tensorflow::gtl::ArraySlice update_window_index) { + [&](absl::Span update_window_index) { return scatter_inner_loop_body( update_window_index, input_scatter_index, update_scatter_index); })); @@ -2171,7 +2328,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 rank = ShapeUtil::Rank(operand->shape()); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto func = [&](tensorflow::gtl::ArraySlice out_index) { + auto func = [&](absl::Span out_index) { DimensionVector operand_index(rank); for (int64 i = 0; i < rank; ++i) { operand_index[i] = @@ -2387,11 +2544,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value || std::is_same::value || std::is_same::value>::type* = nullptr> - Status HandleIota(HloInstruction* iota) { - auto result = MakeUnique(iota->shape()); - auto data = result->data(); + Status HandleIota(HloInstruction* instruction) { + auto* iota = Cast(instruction); + std::vector data(iota->shape().dimensions(iota->iota_dimension())); std::iota(data.begin(), data.end(), 0); - parent_->evaluated_[iota] = std::move(result); + auto result = LiteralUtil::CreateR1(data); + + if (ShapeUtil::Rank(iota->shape()) > 1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[iota], + result->Broadcast(iota->shape(), {iota->iota_dimension()})); + } else { + TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); + parent_->evaluated_[iota] = std::move(result); + } + return Status::OK(); } template & window_count_index, + const absl::Span& window_count_index, const std::function&)>& f) { const int64 rank = ShapeUtil::Rank(base_shape); DimensionVector window_index(rank); @@ -2451,7 +2618,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (!out_of_bound) { f(base_index); } - } while (IndexUtil::BumpIndices(window_shape, &window_index)); + } while ( + IndexUtil::BumpIndices(window_shape, absl::MakeSpan(window_index))); } template @@ -2470,9 +2638,9 @@ class 
HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector operand_indices(start.size()); - auto result = MakeUnique(result_shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + auto result = absl::make_unique(result_shape); + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); operand_indices[i] = multi_index[i] + start[i]; @@ -2503,7 +2671,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector result_index(rank, 0); - auto func = [&](tensorflow::gtl::ArraySlice update_index) { + auto func = [&](absl::Span update_index) { std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus()); result->Set(result_index, @@ -2548,18 +2716,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); + ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape())); } const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { return ConvertBinaryFunction(binary_op)( lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -2584,20 +2751,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape 
Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str(), - ShapeUtil::HumanString(ehs->shape()).c_str()); + ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape()), + ShapeUtil::HumanString(ehs->shape())); } const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { + TF_RETURN_IF_ERROR( + result->Populate([&](absl::Span multi_index) { return ternary_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index), ehs_literal.Get(multi_index)); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index c3ccbf0f0c75b569b49652807dea52faebdccc31..de3d7a167752f0de790585e50874dd6d2904bd37 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -19,6 +19,8 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" @@ -49,7 +51,7 @@ std::unique_ptr CreateHloProfilePrinterData( size_t profile_counters_size = hlo_profile_index_map.total_count(); std::unique_ptr profile_printer_data = - MakeUnique(); + absl::make_unique(); profile_printer_data->set_profile_counters_size(profile_counters_size); profile_printer_data->mutable_computation_infos()->Reserve( hlo_profile_index_map.computation_count()); @@ -67,11 +69,11 @@ std::unique_ptr CreateHloProfilePrinterData( // The profile indices were computed deterministically in // HloProfileIndexMap::HloProfileIndexMap. - c_sort(computation_and_profile_idx_list, - [](const std::pair& left, - const std::pair& right) { - return left.second < right.second; - }); + absl::c_sort(computation_and_profile_idx_list, + [](const std::pair& left, + const std::pair& right) { + return left.second < right.second; + }); for (const auto& pair : computation_and_profile_idx_list) { CHECK_LT(pair.second, profile_counters_size); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index eba80c0f199f6224f4b46ac19af482c713585154..460ae2b5eca78659f86df1227e6a0a4e57508611 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -14,15 +14,15 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -using tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::AllOf; using ::testing::ContainsRegex; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 1efa6eb5bda7e1cb90874e0466aafd2c788a3fbf..3041d94fa9f55b1acffc1295d07e48c967322865 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -26,6 +26,12 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -37,50 +43,25 @@ limitations under the License. 
#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" -using ::tensorflow::Env; -using ::tensorflow::WriteStringToFile; -using ::tensorflow::gtl::nullopt; -using ::tensorflow::gtl::optional; -using ::tensorflow::io::JoinPath; -using ::tensorflow::str_util::Join; -using ::tensorflow::str_util::StringReplace; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace xla { namespace hlo_graph_dumper { namespace { -// Helpers for Printf and Appendf. -template -struct PrintfConvert { - const T& operator()(const T& t) const { return t; } -}; -template <> -struct PrintfConvert { - const char* operator()(const string& s) const { return s.c_str(); } -}; - -// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str() -// on strings. -template -string Printf(const char* fmt, const Ts&... ts) { - return tensorflow::strings::Printf(fmt, PrintfConvert()(ts)...); -} -template -void Appendf(string* s, const char* fmt, const Ts&... ts) { - tensorflow::strings::Appendf(s, fmt, PrintfConvert()(ts)...); -} +using absl::nullopt; +using absl::optional; +using absl::StrAppend; +using absl::StrCat; +using absl::StrFormat; +using absl::StrJoin; +using tensorflow::Env; +using tensorflow::WriteStringToFile; +using tensorflow::io::JoinPath; // Used to indicate how we should treat a given HLOInstruction in the graph. // should we treat it like normal, hide it, and so on? 
@@ -209,17 +190,15 @@ NodeColors NodeColorsForScheme(ColorScheme color) { string NodeColorAttributes(ColorScheme color) { NodeColors node_colors = NodeColorsForScheme(color); - return Printf( - R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", - node_colors.style, node_colors.font_color, node_colors.stroke_color, - node_colors.fill_color); + return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", + node_colors.style, node_colors.font_color, + node_colors.stroke_color, node_colors.fill_color); } // Replaces <> with <>, so that this string is safe(er) for use in a // graphviz HTML-like string. -string HtmlLikeStringSanitize(tensorflow::StringPiece s) { - return StringReplace(StringReplace(s, "<", "<", /*replace_all=*/true), ">", - ">", /*replace_all=*/true); +string HtmlLikeStringSanitize(absl::string_view s) { + return absl::StrReplaceAll(s, {{"<", "<"}, {">", ">"}}); } // Tries to generates a human-readable one-word description of the given @@ -322,11 +301,11 @@ optional MatchTrivialComputation(const HloComputation* computation) { // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax). class HloDotDumper { public: - HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, + HloDotDumper(const HloComputation* computation, absl::string_view label, const DebugOptions& debug_options, bool show_backend_config, const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), - label_(std::string(label)), + label_(label), debug_options_(debug_options), show_backend_config_(show_backend_config), profile_(profile), @@ -448,7 +427,7 @@ string HloDotDumper::Dump() { } string HloDotDumper::Header() { - const char* fmt = R"(digraph G { + constexpr char fmt[] = R"(digraph G { rankdir = TB; compound = true; label = <%s>; @@ -457,7 +436,7 @@ labelloc = t; tooltip = " "; // DOT graphs accept a stylesheet as a URI. So naturally, an inline // stylesheet is a data URI! 
-stylesheet=" +stylesheet=< data:text/css, @import url(https://fonts.googleapis.com/css?family=Roboto:400,700); svg text { @@ -466,7 +445,7 @@ stylesheet=" } %s -" +> )"; @@ -481,8 +460,8 @@ stylesheet=" } if (profile_ != nullptr) { auto cycles = profile_->total_cycles_executed(*computation_); - Appendf(&graph_label, "
total cycles = %lld (%s)", cycles, - tensorflow::strings::HumanReadableNum(cycles)); + absl::StrAppendFormat(&graph_label, "
total cycles = %d (%s)", cycles, + tensorflow::strings::HumanReadableNum(cycles)); } // Create CSS rules that say, when you hover over the given node or cluster, @@ -509,14 +488,14 @@ stylesheet=" // One could imagine other ways of writing this CSS rule that involve // less duplication, but this way seems to be relatively performant. edge_css_rules.push_back( - Printf(" #%s%d:hover ~ #edge%lld text { fill: %s; }\n" - " #%s%d:hover ~ #edge%lld path { " - "stroke: %s; stroke-width: .2em; }\n" - " #%s%d:hover ~ #edge%lld polygon { " - "fill: %s; stroke: %s; stroke-width: .2em; }\n", - elem_type, elem_id, edge_id, color, // - elem_type, elem_id, edge_id, color, // - elem_type, elem_id, edge_id, color, color)); + StrFormat(" #%s%d:hover ~ #edge%d text { fill: %s; }\n" + " #%s%d:hover ~ #edge%d path { " + "stroke: %s; stroke-width: .2em; }\n" + " #%s%d:hover ~ #edge%d polygon { " + "fill: %s; stroke: %s; stroke-width: .2em; }\n", + elem_type, elem_id, edge_id, color, // + elem_type, elem_id, edge_id, color, // + elem_type, elem_id, edge_id, color, color)); }; // The "to_node" value may be a NULL, indicating that this points to the @@ -559,10 +538,10 @@ stylesheet=" } } - return Printf(fmt, graph_label, Join(edge_css_rules, "\n")); + return StrFormat(fmt, graph_label, StrJoin(edge_css_rules, "\n")); } -string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); } +string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); } bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) { CHECK_EQ(instr->opcode(), HloOpcode::kFusion); @@ -600,9 +579,9 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() << " as " << next_edge_id_; edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); - const char* edge_fmt = + constexpr char edge_fmt[] = R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back(Printf( 
+ edges_.push_back(StrFormat( edge_fmt, InstructionId(from), InstructionId(parent_instr), SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); } @@ -619,9 +598,10 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, string subcomp_label, style; if (parent_instr->opcode() == HloOpcode::kFusion) { - subcomp_label = Printf("Fused expression for %s
%s", - HtmlLikeStringSanitize(parent_instr->name()), - HtmlLikeStringSanitize(parent_instr->ToCategory())); + subcomp_label = + StrFormat("Fused expression for %s
%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(parent_instr->ToCategory())); string extra_info = GetInstructionNodeExtraInfo(parent_instr); if (!extra_info.empty()) { StrAppend(&subcomp_label, "
", extra_info); @@ -647,18 +627,18 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; } style = - Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", - fillcolor, strokecolor); + StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", + fillcolor, strokecolor); } else { - subcomp_label = Printf("Subcomputation for %s
%s", - HtmlLikeStringSanitize(parent_instr->name()), - HtmlLikeStringSanitize(subcomp->name())); + subcomp_label = StrFormat("Subcomputation for %s
%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(subcomp->name())); style = "style=rounded; color=black;"; } string comp_body = DumpComputation(subcomp); - const char* computation_fmt = R"(subgraph %s { + constexpr char computation_fmt[] = R"(subgraph %s { %s label = <%s>; labelloc = t; @@ -667,7 +647,7 @@ tooltip = " "; } // %s )"; - return Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id); } string HloDotDumper::DumpComputation(const HloComputation* comp) { @@ -718,11 +698,11 @@ string HloDotDumper::DumpRootTag() { VLOG(2) << "Adding edge from " << from->name() << " to root tag as " << next_edge_id_; edge_ids_.insert({{from, to}, next_edge_id_++}); - edges_.push_back(Printf(R"(%s -> %s [tooltip=" "];)", from_id, to_id)); + edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id)); - return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" - "\n", - to_id, node_body, node_shape, NodeColorAttributes(color)); + return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" + "\n", + to_id, node_body, node_shape, NodeColorAttributes(color)); } static const HloConstantInstruction* TryGetFusionParameterConstant( @@ -817,10 +797,10 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } } - return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" - "\n", - InstructionId(instr), node_body, node_shape, node_metadata, - NodeColorAttributes(color)); + return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" + "\n", + InstructionId(instr), node_body, node_shape, node_metadata, + NodeColorAttributes(color)); } string HloDotDumper::GetInstructionNodeInlinedOperands( @@ -833,7 +813,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which // is just noise. 
if (ShapeUtil::IsZeroElementArray(shape)) { - return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); + return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape())); } // Print the literal value of constants with <= K elements. @@ -848,19 +828,19 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // collected from profiling tools. Those constants may not have a valid // literal. if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { - return Printf("%s (%s)", constant->literal().ToString(), - ShapeUtil::HumanString(constant->shape())); + return StrFormat("%s (%s)", constant->literal().ToString(), + ShapeUtil::HumanString(constant->shape())); } // Otherwise, print e.g. "%constant.42 (s32[100])". string constant_name; - if (tensorflow::str_util::StartsWith(constant->name(), "constant")) { + if (absl::StartsWith(constant->name(), "constant")) { constant_name = constant->name(); } else { constant_name = StrCat("constant ", constant->name()); } - return Printf("%s %s", constant_name, - ShapeUtil::HumanString(constant->shape())); + return StrFormat("%s %s", constant_name, + ShapeUtil::HumanString(constant->shape())); }; std::vector lines; @@ -881,7 +861,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( TryGetFusionParameterConstant(operand)) { operand_str = stringify_constant(constant); } else { - operand_str = Printf("Parameter %lld", operand->parameter_number()); + operand_str = StrFormat("Parameter %d", operand->parameter_number()); } } else { operand_str = operand->name(); @@ -890,13 +870,13 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( if (operand_str) { if (instr->operand_count() > 1) { - lines.push_back(Printf("operand %lld = %s", i, *operand_str)); + lines.push_back(StrFormat("operand %d = %s", i, *operand_str)); } else { - lines.push_back(Printf("operand = %s", *operand_str)); + lines.push_back(StrFormat("operand = %s", *operand_str)); } } } - return Join(lines, "
"); + return StrJoin(lines, "
"); } ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { @@ -1049,6 +1029,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kGray; case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kRecv: @@ -1059,7 +1040,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: - case HloOpcode::kHostCompute: case HloOpcode::kWhile: return kDarkGreen; case HloOpcode::kConstant: @@ -1080,14 +1060,13 @@ string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { // If we have a parameter, put the param number in the name. if (instr->opcode() == HloOpcode::kParameter) { - return Printf("Parameter %lld", instr->parameter_number()); + return StrFormat("Parameter %d", instr->parameter_number()); } // The HLO instruction name contains usually the opcode, e.g. "%add.42" is // an add instruction. In this case we render just the name. - if (tensorflow::str_util::StartsWith(instr->name(), - HloOpcodeString(instr->opcode()))) { - return Printf("%s", HtmlLikeStringSanitize(instr->name())); + if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) { + return StrFormat("%s", HtmlLikeStringSanitize(instr->name())); } string extended_opcode = StrCat(HloOpcodeString(instr->opcode()), @@ -1095,8 +1074,8 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { ? "" : StrCat(":", xla::ToString(instr->fusion_kind()))); // If the name does not contain the opcode, render both. - return Printf("%s
%s", HtmlLikeStringSanitize(extended_opcode), - HtmlLikeStringSanitize(instr->name())); + return StrFormat("%s
%s", HtmlLikeStringSanitize(extended_opcode), + HtmlLikeStringSanitize(instr->name())); } string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { @@ -1105,16 +1084,16 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name())); } if (!instr->metadata().op_type().empty()) { - lines.push_back(Printf( + lines.push_back(StrFormat( "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type()))); } if (!instr->metadata().source_file().empty() && instr->metadata().source_line() != 0) { - lines.push_back(Printf("op_type: %s", instr->metadata().source_file(), - instr->metadata().source_line())); + lines.push_back(StrFormat("op_type: %s:%d", instr->metadata().source_file(), + instr->metadata().source_line())); } - return Join(lines, "
"); + return StrJoin(lines, "
"); } string HloDotDumper::GetInstructionNodeBackendConfig( @@ -1161,13 +1140,12 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { constexpr int kMaxShapeLen = 64; if (instr_shape.length() > kMaxShapeLen) { instr_shape = StrCat( - tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3), - "..."); + absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "..."); } lines.push_back(instr_shape); } if (debug_options_.xla_hlo_graph_addresses()) { - lines.push_back(Printf("[%p]", instr)); + lines.push_back(StrFormat("[%p]", instr)); } if (profile_ != nullptr) { double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr); @@ -1175,11 +1153,11 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { profile_->total_cycles_executed(*instr->parent()); if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { lines.push_back( - Printf("%% of cycles executed=%.2f", - 100 * hlo_cycles_executed / total_cycles_executed)); + StrFormat("%% of cycles executed=%.2f", + 100 * hlo_cycles_executed / total_cycles_executed)); } } - return Join(lines, "
"); + return StrJoin(lines, "
"); } // Gets the total number of array elements in the given shape. For tuples, this @@ -1211,7 +1189,8 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { string edge_label; if (instr->operand_count() > 1 && !control_edge) { - edge_label = Printf(R"( headlabel="%lld", labeldistance=2)", operand_num); + edge_label = + StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num); } else if (control_edge) { edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\""; } @@ -1221,10 +1200,11 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { // means. bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; - const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; - edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), - (is_big_array ? "normal" : "empty"), from->name(), - to->name(), edge_label)); + constexpr char kEdgeFmt[] = + R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; + edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to), + (is_big_array ? "normal" : "empty"), + from->name(), to->name(), edge_label)); }; // Add edges from instr's operands to instr. Parameters within fusion @@ -1265,14 +1245,14 @@ string HloDotDumper::GetInstructionTrivialComputationStr( continue; } if (instr->called_computations().size() == 1) { - lines.push_back(Printf("Subcomputation: %s", - HtmlLikeStringSanitize(*computation_type))); + lines.push_back(StrFormat("Subcomputation: %s", + HtmlLikeStringSanitize(*computation_type))); } else { - lines.push_back(Printf("Subcomputation %lld: %s", i, - HtmlLikeStringSanitize(*computation_type))); + lines.push_back(StrFormat("Subcomputation %d: %s", i, + HtmlLikeStringSanitize(*computation_type))); } } - return Join(lines, "
"); + return StrJoin(lines, "
"); } const HloInstruction* HloDotDumper::GetNodeForEdge( diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 1d7a062c55696de9db4b187efd86bce191279083..064c53252c0ac4d4e7b93169ad7cbee4807cb963 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,12 +24,11 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -using ::tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::HasSubstr; string TestName() { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 8690f2cdaa9b45d126e91b123c6992cbe2f27e1d..6d13f85cbbca2ae4b2a794ca5de975fe21e8212e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -21,10 +21,17 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -39,17 +46,15 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using tensorflow::str_util::CEscape; -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::CEscape; +using absl::StrAppend; +using absl::StrCat; +using absl::StrJoin; /* static */ StatusOr> HloInstruction::CreateFromProto( @@ -108,7 +113,7 @@ StatusOr> HloInstruction::CreateFromProto( std::vector fft_length(proto.fft_length().begin(), proto.fft_length().end()); instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), - tensorflow::gtl::ArraySlice(fft_length)); + absl::Span(fft_length)); break; } case HloOpcode::kSend: @@ -153,16 +158,26 @@ StatusOr> HloInstruction::CreateFromProto( CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0)); break; case HloOpcode::kReduce: - TF_RET_CHECK(proto.operand_ids_size() == 2) - << "Reduce instruction should have 2 operands but sees " + 
TF_RET_CHECK(proto.operand_ids_size() % 2 == 0) + << "Reduce instruction should have an even number of operands but " + "sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "Reduce instruction should have 1 called computation but sees " << proto.called_computation_ids_size(); - instruction = CreateReduce(proto.shape(), operands(0), operands(1), - std::vector(proto.dimensions().begin(), - proto.dimensions().end()), - computations(0)); + { + const auto reduce_operands = all_operands(); + auto inputs = absl::MakeSpan(reduce_operands) + .subspan(0, reduce_operands.size() / 2); + auto init_values = + absl::MakeSpan(reduce_operands) + .subspan(reduce_operands.size() / 2, reduce_operands.size()); + instruction = + CreateReduce(proto.shape(), inputs, init_values, + std::vector(proto.dimensions().begin(), + proto.dimensions().end()), + computations(0)); + } break; case HloOpcode::kSort: { TF_RET_CHECK(proto.operand_ids_size() == 1 || @@ -224,7 +239,7 @@ StatusOr> HloInstruction::CreateFromProto( Literal::CreateFromProto(proto.literal())); instruction = CreateConstant(std::move(literal)); } else { - instruction = MakeUnique(proto.shape()); + instruction = absl::make_unique(proto.shape()); } break; } @@ -281,41 +296,28 @@ StatusOr> HloInstruction::CreateFromProto( case HloOpcode::kInfeed: { const Shape& data_shape = ShapeUtil::GetTupleElementShape(proto.shape(), 0); - if (proto.operand_ids_size() == 0) { - // TODO(b/80000000): Remove this when all uses of infeed are - // converted to take tokens. 
- instruction = CreateInfeed(data_shape, proto.infeed_config()); - } else { - CHECK_EQ(proto.operand_ids_size(), 1); - instruction = - CreateInfeed(data_shape, operands(0), proto.infeed_config()); - } + TF_RET_CHECK(proto.operand_ids_size() == 1); + instruction = + CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; case HloOpcode::kOutfeed: - if (proto.operand_ids_size() == 1) { - // TODO(b/80000000): Remove this when all uses of outfeed are - // converted to take tokens. - instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), - proto.outfeed_config()); - } else { - CHECK_EQ(proto.operand_ids_size(), 2); - instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), - operands(1), proto.outfeed_config()); - } + TF_RET_CHECK(proto.operand_ids_size() == 2); + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + operands(1), proto.outfeed_config()); break; case HloOpcode::kCrossReplicaSum: { TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "CrossReplicaSum should have 1 called computation but sees " << proto.called_computation_ids_size(); - tensorflow::gtl::optional all_reduce_id; + absl::optional all_reduce_id; if (proto.all_reduce_id() > 0) { all_reduce_id = proto.all_reduce_id(); } instruction = CreateCrossReplicaSum( proto.shape(), all_operands(), computations(0), - /*replica_group_ids=*/ - std::vector(proto.replica_group_ids().begin(), - proto.replica_group_ids().end()), + /*replica_groups=*/ + std::vector(proto.replica_groups().begin(), + proto.replica_groups().end()), /*barrier=*/proto.cross_replica_sum_barrier(), /*all_reduce_id=*/all_reduce_id); break; @@ -325,8 +327,18 @@ StatusOr> HloInstruction::CreateFromProto( proto.shape(), all_operands(), /*replica_groups=*/ std::vector(proto.replica_groups().begin(), - proto.replica_groups().end()), - /*barrier=*/proto.cross_replica_sum_barrier()); + proto.replica_groups().end())); + break; + } + case HloOpcode::kCollectivePermute: { + std::vector> 
source_target_pairs( + proto.source_target_pairs_size()); + for (int i = 0; i < source_target_pairs.size(); i++) { + source_target_pairs[i].first = proto.source_target_pairs(i).source(); + source_target_pairs[i].second = proto.source_target_pairs(i).target(); + } + instruction = CreateCollectivePermute(proto.shape(), operands(0), + source_target_pairs); break; } case HloOpcode::kConvolution: @@ -335,9 +347,10 @@ StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); - instruction = - CreateConvolve(proto.shape(), operands(0), operands(1), - proto.window(), proto.convolution_dimension_numbers()); + instruction = CreateConvolve( + proto.shape(), operands(0), operands(1), proto.window(), + proto.convolution_dimension_numbers(), + std::max(static_cast(proto.feature_group_count()), 1LL)); break; case HloOpcode::kReduceWindow: TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -372,11 +385,9 @@ StatusOr> HloInstruction::CreateFromProto( ->set_convolution_dimension_numbers( proto.convolution_dimension_numbers()); } - break; - case HloOpcode::kHostCompute: - instruction = - CreateHostCompute(proto.shape(), all_operands(), proto.channel_name(), - proto.cost_estimate_ns()); + static_cast(instruction.get()) + ->set_feature_group_count( + std::max(static_cast(proto.feature_group_count()), 1LL)); break; case HloOpcode::kPad: TF_RET_CHECK(proto.operand_ids_size() == 2) @@ -391,7 +402,7 @@ StatusOr> HloInstruction::CreateFromProto( << "DynamicSlice instruction should have 2 operands but sees " << proto.operand_ids_size(); std::vector slice_sizes(proto.dynamic_slice_sizes_size()); - c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); + absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1), slice_sizes); break; @@ -403,14 +414,14 @@ StatusOr> HloInstruction::CreateFromProto( 
TF_RET_CHECK(proto.has_gather_dimension_numbers()) << "Gather instruction should have GatherDimensionNumbers set."; std::unique_ptr gather_dimension_numbers = - MakeUnique(proto.gather_dimension_numbers()); - std::vector gather_window_bounds; - for (int64 bound : proto.gather_window_bounds()) { - gather_window_bounds.push_back(bound); + absl::make_unique( + proto.gather_dimension_numbers()); + std::vector gather_slice_sizes; + for (int64 bound : proto.gather_slice_sizes()) { + gather_slice_sizes.push_back(bound); } - instruction = - CreateGather(proto.shape(), operands(0), operands(1), - *gather_dimension_numbers, gather_window_bounds); + instruction = CreateGather(proto.shape(), operands(0), operands(1), + *gather_dimension_numbers, gather_slice_sizes); break; } case HloOpcode::kScatter: { @@ -422,15 +433,22 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "Scatter instruction should have 1 called computation but sees " << proto.called_computation_ids_size(); - auto scatter_dimension_numbers = MakeUnique( - proto.scatter_dimension_numbers()); + auto scatter_dimension_numbers = + absl::make_unique( + proto.scatter_dimension_numbers()); instruction = CreateScatter(proto.shape(), operands(0), operands(1), operands(2), computations(0), *scatter_dimension_numbers); break; } + case HloOpcode::kIota: + TF_RET_CHECK(proto.dimensions_size() <= 1) + << "Iota instruction should have at most 1 dimension but sees " + << proto.dimensions_size(); + instruction = CreateIota(proto.shape(), proto.dimensions(0)); + break; default: { - instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) << "No instruction with id " << operand_id; @@ -458,10 +476,11 @@ StatusOr> HloInstruction::CreateFromProto( 
instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); + instruction->precision_config_ = proto.precision_config(); if (proto.has_dot_dimension_numbers()) { instruction->dot_dimension_numbers_ = - MakeUnique(proto.dot_dimension_numbers()); + absl::make_unique(proto.dot_dimension_numbers()); } if (proto.has_sharding()) { @@ -475,44 +494,46 @@ StatusOr> HloInstruction::CreateFromProto( /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { - return MakeUnique(parameter_number, shape, name); + return absl::make_unique(parameter_number, shape, + name); } /* static */ std::unique_ptr HloInstruction::CreateTrace( const string& tag, HloInstruction* operand) { - return MakeUnique(tag, operand); + return absl::make_unique(tag, operand); } /* static */ std::unique_ptr HloInstruction::CreateConstant( std::unique_ptr literal) { - return MakeUnique(std::move(literal)); + return absl::make_unique(std::move(literal)); } /* static */ std::unique_ptr HloInstruction::CreateIota( - const Shape& shape) { - return WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); + const Shape& shape, int64 iota_dimension) { + return absl::make_unique(shape, iota_dimension); } /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { - return MakeUnique(shape, operand, index); + return absl::make_unique(shape, operand, + index); } /* static */ std::unique_ptr HloInstruction::CreateRng( const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters) { - return MakeUnique(shape, distribution, parameters); + absl::Span parameters) { + return absl::make_unique(shape, distribution, parameters); } /* static */ std::unique_ptr HloInstruction::CreateNary( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { + 
absl::Span operands) { if (opcode == HloOpcode::kCopy) { // It is impossible to copy an opaque shape, we don't know how big it is. CHECK(!ShapeUtil::IsOpaque(shape)); } - auto instruction = WrapUnique(new HloInstruction(opcode, shape)); + auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (auto operand : operands) { instruction->AppendOperand(operand); } @@ -609,39 +630,41 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateVariadic( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { CHECK_EQ(HloOpcode::kTuple, opcode); return CreateNary(shape, opcode, operands); } /* static */ std::unique_ptr HloInstruction::CreateMap( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* map_computation) { - return MakeUnique(shape, operands, map_computation); + return absl::make_unique(shape, operands, map_computation); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers) { - return MakeUnique(shape, lhs, rhs, window, - dimension_numbers); + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { + return absl::make_unique( + shape, lhs, rhs, window, dimension_numbers, feature_group_count); } /* static */ std::unique_ptr HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length) { - return MakeUnique(shape, operand, fft_type, fft_length); + absl::Span fft_length) { + return absl::make_unique(shape, operand, fft_type, + fft_length); } /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers) { - auto 
instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); instruction->dot_dimension_numbers_ = - MakeUnique(dimension_numbers); + absl::make_unique(dimension_numbers); return instruction; } @@ -650,10 +673,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = MakeUnique(); + instruction->dot_dimension_numbers_ = + absl::make_unique(); instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); return instruction; @@ -664,60 +689,55 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits) { - return MakeUnique( + return absl::make_unique( shape, operand, exponent_bits, mantissa_bits); } /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice replica_group_ids, - tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id) { - return MakeUnique( - shape, operands, reduce_computation, replica_group_ids, barrier, + const std::vector& replica_groups, absl::string_view barrier, + const absl::optional& all_reduce_id) { + return absl::make_unique( + shape, operands, reduce_computation, replica_groups, barrier, all_reduce_id); } /* static */ std::unique_ptr HloInstruction::CreateAllToAll( - 
const Shape& shape, tensorflow::gtl::ArraySlice operands, - const std::vector& replica_groups, - tensorflow::StringPiece barrier) { - return MakeUnique(shape, operands, replica_groups, - barrier); + const Shape& shape, absl::Span operands, + const std::vector& replica_groups) { + return absl::make_unique(shape, operands, + replica_groups); } -/* static */ std::unique_ptr HloInstruction::CreateInfeed( - const Shape& infeed_shape, HloInstruction* token_operand, - const string& config) { - return MakeUnique(infeed_shape, token_operand, config); +/* static */ std::unique_ptr +HloInstruction::CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs) { + return absl::make_unique( + shape, operand, source_target_pairs); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( - const Shape& infeed_shape, const string& config) { - return MakeUnique(infeed_shape, config); -} - -/* static */ std::unique_ptr HloInstruction::CreateOutfeed( - const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) { - return MakeUnique(outfeed_shape, operand, - token_operand, outfeed_config); + const Shape& infeed_shape, HloInstruction* token_operand, + const string& config) { + return absl::make_unique(infeed_shape, token_operand, + config); } /* static */ std::unique_ptr HloInstruction::CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, - tensorflow::StringPiece outfeed_config) { - return MakeUnique(outfeed_shape, operand, - outfeed_config); + HloInstruction* token_operand, absl::string_view outfeed_config) { + return absl::make_unique( + outfeed_shape, operand, token_operand, outfeed_config); } /* static */ std::unique_ptr HloInstruction::CreateSend( HloInstruction* operand, HloInstruction* token, int64 channel_id, bool is_host_transfer) { - return MakeUnique(operand, token, channel_id, - is_host_transfer); + return absl::make_unique(operand, 
token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( @@ -725,14 +745,15 @@ HloInstruction::CreateCrossReplicaSum( auto send_operand = DynCast(operand); CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - return MakeUnique(send_operand, is_host_transfer); + return absl::make_unique(send_operand, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, HloInstruction* token, int64 channel_id, bool is_host_transfer) { - return MakeUnique(shape, token, channel_id, - is_host_transfer); + return absl::make_unique(shape, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( @@ -740,19 +761,20 @@ HloInstruction::CreateCrossReplicaSum( auto recv_operand = DynCast(operand); CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - return MakeUnique(recv_operand, is_host_transfer); + return absl::make_unique(recv_operand, + is_host_transfer); } /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) { - return MakeUnique(shape, operand, dimensions); + absl::Span dimensions) { + return absl::make_unique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateAfterAll( - tensorflow::gtl::ArraySlice operands) { + absl::Span operands) { CHECK(!operands.empty()); - auto instruction = WrapUnique( + auto instruction = absl::WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); for (auto operand : operands) { instruction->AppendOperand(operand); @@ -761,14 +783,15 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr HloInstruction::CreateToken() { - return WrapUnique( + return absl::WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); } /* static */ std::unique_ptr 
HloInstruction::CreateWhile( const Shape& shape, HloComputation* condition, HloComputation* body, HloInstruction* init) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); instruction->AppendOperand(init); // Body comes before condition computation in the vector. instruction->called_computations_.push_back(body); @@ -781,7 +804,7 @@ HloInstruction::CreateCrossReplicaSum( HloInstruction* true_computation_arg, HloComputation* true_computation, HloInstruction* false_computation_arg, HloComputation* false_computation) { auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); instruction->AppendOperand(pred); instruction->AppendOperand(true_computation_arg); instruction->AppendOperand(false_computation_arg); @@ -795,18 +818,17 @@ HloInstruction::CreateCrossReplicaSum( /* static */ std::unique_ptr HloInstruction::CreateSlice( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { - return MakeUnique(shape, operand, start_indices, - limit_indices, strides); + absl::Span start_indices, + absl::Span limit_indices, absl::Span strides) { + return absl::make_unique(shape, operand, start_indices, + limit_indices, strides); } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { - return MakeUnique(shape, operand, start_indices, - slice_sizes); + absl::Span slice_sizes) { + return absl::make_unique( + shape, operand, start_indices, slice_sizes); } /* static */ std::unique_ptr @@ -814,8 +836,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction* operand, HloInstruction* update, 
HloInstruction* start_indices) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); + auto instruction = absl::WrapUnique( + new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); instruction->AppendOperand(operand); instruction->AppendOperand(update); instruction->AppendOperand(start_indices); @@ -823,14 +845,16 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, } /* static */ std::unique_ptr HloInstruction::CreateConcatenate( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, int64 dimension) { - return MakeUnique(shape, operands, dimension); + return absl::make_unique(shape, operands, + dimension); } /* static */ std::unique_ptr HloInstruction::CreateConvert( const Shape& shape, HloInstruction* operand) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape)); instruction->AppendOperand(operand); return instruction; } @@ -839,38 +863,38 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction::CreateBitcastConvert(const Shape& shape, HloInstruction* operand) { auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); instruction->AppendOperand(operand); return instruction; } /* static */ std::unique_ptr HloInstruction::CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation) { - auto instruction = WrapUnique(new HloReduceInstruction( + auto instruction = absl::WrapUnique(new HloReduceInstruction( shape, {operand, init_value}, dimensions_to_reduce, reduce_computation)); return std::move(instruction); } /* static */ std::unique_ptr HloInstruction::CreateReduce( - 
const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + const Shape& shape, absl::Span operands, + absl::Span init_values, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation) { std::vector all_args; all_args.reserve(operands.size() * 2); all_args.insert(all_args.end(), operands.begin(), operands.end()); all_args.insert(all_args.end(), init_values.begin(), init_values.end()); - return MakeUnique(shape, all_args, dimensions_to_reduce, - reduce_computation); + return absl::make_unique( + shape, all_args, dimensions_to_reduce, reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation) { - return MakeUnique(shape, operand, init_value, - window, reduce_computation); + return absl::make_unique( + shape, operand, init_value, window, reduce_computation); } /* static */ std::unique_ptr @@ -879,7 +903,7 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape, HloInstruction* scale, HloInstruction* offset, float epsilon, int64 feature_index) { - return MakeUnique( + return absl::make_unique( shape, operand, scale, offset, epsilon, feature_index); } @@ -888,7 +912,7 @@ HloInstruction::CreateBatchNormInference( const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, float epsilon, int64 feature_index) { - return MakeUnique( + return absl::make_unique( shape, operand, scale, offset, mean, variance, epsilon, feature_index); } @@ -898,9 +922,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, HloInstruction* variance, HloInstruction* grad_output, float epsilon, int64 feature_index) { - return MakeUnique(shape, operand, scale, mean, - variance, grad_output, epsilon, - feature_index); + 
return absl::make_unique( + shape, operand, scale, mean, variance, grad_output, epsilon, + feature_index); } /* static */ std::unique_ptr @@ -908,15 +932,15 @@ HloInstruction::CreateSelectAndScatter( const Shape& shape, HloInstruction* operand, HloComputation* select, const Window& window, HloInstruction* source, HloInstruction* init_value, HloComputation* scatter) { - return MakeUnique( + return absl::make_unique( shape, operand, select, window, source, init_value, scatter); } /* static */ std::unique_ptr HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return MakeUnique(shape, operand, - broadcast_dimensions); + absl::Span broadcast_dimensions) { + return absl::make_unique(shape, operand, + broadcast_dimensions); } /* static */ std::unique_ptr @@ -974,8 +998,8 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreatePad( const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { - return MakeUnique(shape, operand, padding_value, - padding_config); + return absl::make_unique(shape, operand, padding_value, + padding_config); } /* static */ std::unique_ptr HloInstruction::CreateReshape( @@ -984,34 +1008,36 @@ HloInstruction::CreateBroadcastSequence( ShapeUtil::ElementsIn(operand->shape())) << "shape: " << ShapeUtil::HumanString(shape) << " operand: " << ShapeUtil::HumanString(operand->shape()); - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); instruction->AppendOperand(operand); return instruction; } /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) { - return MakeUnique(shape, operand, dimensions); + absl::Span dimensions) { + return absl::make_unique(shape, operand, 
dimensions); } /* static */ std::unique_ptr HloInstruction::CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, HloInstruction* values) { - return MakeUnique(shape, dimension, keys, values); + return absl::make_unique(shape, dimension, keys, values); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { - return MakeUnique(shape, fusion_kind, fused_root); + return absl::make_unique(shape, fusion_kind, + fused_root); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, HloComputation* fusion_computation) { - return MakeUnique(shape, fusion_kind, operands, - fusion_computation); + return absl::make_unique(shape, fusion_kind, operands, + fusion_computation); } void HloInstruction::set_single_sharding(const HloSharding& sharding) { @@ -1031,6 +1057,7 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->clear_sharding(); } derived_instruction->set_metadata(metadata_); + derived_instruction->set_precision_config(precision_config_); } bool HloInstruction::HasSideEffectNoRecurse() const { @@ -1043,7 +1070,6 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: - case HloOpcode::kHostCompute: return true; case HloOpcode::kCrossReplicaSum: return all_reduce_id().has_value(); @@ -1066,10 +1092,10 @@ bool HloInstruction::HasSideEffect() const { } /* static */ std::unique_ptr HloInstruction::CreateCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* computation) { std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); for (auto operand : operands) { instruction->AppendOperand(operand); } @@ -1078,21 
+1104,14 @@ bool HloInstruction::HasSideEffect() const { } /* static */ std::unique_ptr HloInstruction::CreateCustomCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) { - return MakeUnique(shape, operands, - custom_call_target); -} - -/* static */ std::unique_ptr HloInstruction::CreateHostCompute( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { - return MakeUnique(shape, operands, channel_name, - cost_estimate_ns); + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target) { + return absl::make_unique(shape, operands, + custom_call_target); } /* static */ std::unique_ptr HloInstruction::CreateTuple( - tensorflow::gtl::ArraySlice elements) { + absl::Span elements) { std::vector element_shapes; for (auto element : elements) { element_shapes.push_back(element->shape()); @@ -1102,11 +1121,11 @@ bool HloInstruction::HasSideEffect() const { } /* static */ std::unique_ptr HloInstruction::CreateGather( - const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds) { - return MakeUnique(shape, operand, gather_indices, - gather_dim_numbers, window_bounds); + absl::Span slice_sizes) { + return absl::make_unique( + shape, operand, start_indices, gather_dim_numbers, slice_sizes); } /* static */ std::unique_ptr HloInstruction::CreateScatter( @@ -1114,16 +1133,17 @@ bool HloInstruction::HasSideEffect() const { HloInstruction* scatter_indices, HloInstruction* updates, HloComputation* update_computation, const ScatterDimensionNumbers& scatter_dim_numbers) { - return MakeUnique(shape, operand, scatter_indices, - updates, update_computation, - scatter_dim_numbers); + return absl::make_unique( + shape, operand, scatter_indices, 
updates, update_computation, + scatter_dim_numbers); } /* static */ std::unique_ptr HloInstruction::CreateDomain( const Shape& shape, HloInstruction* operand, std::unique_ptr operand_side_metadata, std::unique_ptr user_side_metadata) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); instruction->operand_side_metadata_ = std::move(operand_side_metadata); instruction->user_side_metadata_ = std::move(user_side_metadata); instruction->AppendOperand(operand); @@ -1131,8 +1151,7 @@ bool HloInstruction::HasSideEffect() const { } std::unique_ptr HloInstruction::CloneWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; @@ -1171,13 +1190,13 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kReducePrecision: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: - case HloOpcode::kHostCompute: case HloOpcode::kPad: case HloOpcode::kDynamicSlice: case HloOpcode::kSort: @@ -1299,6 +1318,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } break; } + // SetupDerivedInstruction will setup the precision_config_ field. SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); clone->set_raw_backend_config_string(backend_config_); @@ -1364,7 +1384,7 @@ std::unique_ptr HloInstruction::Clone( // If names ends with .suffix[0-9]+ then replace with a suffix with the // numeric value incremented. 
int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) { clone->name_ = StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1); } else { @@ -1482,7 +1502,7 @@ void HloInstruction::AppendOperand(HloInstruction* operand) { } void HloInstruction::RemoveOperandsAtAscendingIndices( - tensorflow::gtl::ArraySlice ascending_indices) { + absl::Span ascending_indices) { if (ascending_indices.empty()) { return; } @@ -1639,11 +1659,11 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: - case HloOpcode::kHostCompute: case HloOpcode::kPad: case HloOpcode::kDynamicSlice: case HloOpcode::kGather: @@ -1837,7 +1857,7 @@ void HloInstruction::set_false_computation(HloComputation* false_computation) { string HloInstruction::SignatureString() const { string operands = - Join(operands_, ", ", [](string* out, HloInstruction* operand) { + StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) { StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); @@ -1857,7 +1877,7 @@ string HloInstruction::ToString(const HloPrintOptions& options) const { } bool HloInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const { + const absl::optional& operand_idx) const { switch (opcode_) { // Unary elementwise operations. 
case HloOpcode::kAbs: @@ -1978,13 +1998,13 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { string operands; - tensorflow::gtl::ArraySlice slice(operands_); + absl::Span slice(operands_); const int64 kMaxOperandsToShowIfCompact = 4; if (options.compact_operands() && slice.size() > kMaxOperandsToShowIfCompact) { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } - operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) { // If operand is already been deleted, put `null` to the string output. if (operand == nullptr) { StrAppend(out, "null "); @@ -2004,7 +2024,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( } else if (!options.compact_operands()) { str.push_back(PrintName(operand->name(), options)); } - StrAppend(out, Join(str, " ")); + StrAppend(out, StrJoin(str, " ")); }); const int64 remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { @@ -2021,6 +2041,11 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back(DotDimensionNumbersToString()); } + string precision_config_string = PrecisionConfigToString(); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { @@ -2046,11 +2071,11 @@ std::vector HloInstruction::ExtraAttributesToString( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { extra.push_back(StrCat( - "calls=", Join(called_computations(), ", ", - [&](string* out, const HloComputation* computation) { - StrAppend(out, - PrintName(computation->name(), options)); - }))); + "calls=", + StrJoin(called_computations(), ", ", + [&](string* out, const HloComputation* computation) 
{ + StrAppend(out, PrintName(computation->name(), options)); + }))); } } else if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kFullBodies) { @@ -2083,12 +2108,12 @@ std::vector HloInstruction::ExtraAttributesToString( break; default: if (!called_computations().empty()) { - extra.push_back( - StrCat("calls=\n", - Join(called_computations(), ", ", - [&](string* out, const HloComputation* computation) { - StrAppend(out, computation->ToString(new_options)); - }))); + extra.push_back(StrCat( + "calls=\n", + StrJoin(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, computation->ToString(new_options)); + }))); } break; } @@ -2099,11 +2124,11 @@ std::vector HloInstruction::ExtraAttributesToString( } if (!control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", - Join(control_predecessors_, ", ", - [&](string* out, HloInstruction* pre) { - StrAppend(out, - PrintName(pre->name(), options)); - }), + StrJoin(control_predecessors_, ", ", + [&](string* out, HloInstruction* pre) { + StrAppend(out, + PrintName(pre->name(), options)); + }), "}")); } if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { @@ -2117,10 +2142,10 @@ std::vector HloInstruction::ExtraAttributesToString( string HloInstruction::ToShortString() const { return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(", - Join(operands_, ", ", - [](string* out, HloInstruction* operand) { - StrAppend(out, "%", operand->name()); - }), + StrJoin(operands_, ", ", + [](string* out, HloInstruction* operand) { + StrAppend(out, "%", operand->name()); + }), ")"); } @@ -2142,6 +2167,7 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; proto.set_backend_config(backend_config_); + *proto.mutable_precision_config() = precision_config_; if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { 
proto.add_called_computation_ids(computation->unique_id()); @@ -2180,7 +2206,7 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } -bool HloInstruction::IsFusable() const { +bool HloInstruction::IsFusible() const { // Instructions which are traced should not be fused. if (tracing()) { return false; @@ -2286,6 +2312,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleCrossReplicaSum(this); case HloOpcode::kAllToAll: return visitor->HandleAllToAll(this); + case HloOpcode::kCollectivePermute: + return visitor->HandleCollectivePermute(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2354,8 +2382,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleInfeed(this); case HloOpcode::kOutfeed: return visitor->HandleOutfeed(this); - case HloOpcode::kHostCompute: - return visitor->HandleHostCompute(this); case HloOpcode::kRng: return visitor->HandleRng(this); case HloOpcode::kWhile: @@ -2394,15 +2420,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return InternalError( "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " "please file a bug for XLA.", - HloOpcodeString(opcode_).c_str()); + HloOpcodeString(opcode_)); } // Explicit instantiations. template Status HloInstruction::Visit(DfsHloVisitor* visitor); template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor); -using DFSStack = - tensorflow::gtl::InlinedVector, 16>; +using DFSStack = absl::InlinedVector, 16>; // Push "child" onto the dfs_stack if not already visited. Returns false if a // cycle was detected, and true otherwise. 
@@ -2478,7 +2503,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) { return FailedPrecondition( "A cycle is detected while visiting instruction %s", - current_node->ToString().c_str()); + current_node->ToString()); } } @@ -2487,7 +2512,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) { return FailedPrecondition( "A cycle is detected while visiting instruction %s", - current_node->ToString().c_str()); + current_node->ToString()); } } } @@ -2647,7 +2672,7 @@ bool HloInstruction::IsElementwiseBinary() const { } bool HloInstruction::IsElementwise() const { - return IsElementwiseImpl(tensorflow::gtl::nullopt); + return IsElementwiseImpl(absl::nullopt); } bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { @@ -2735,10 +2760,13 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { case HloOpcode::kTranspose: return UseKind::kUsePermutingElements; case HloOpcode::kPad: - case HloOpcode::kReduce: // Pad reuses the padding value but not the padded array elements. - // Reduce reuses the init value but not the operand array elements. return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements; + case HloOpcode::kReduce: + // Reduce reuses the init values but not the operand array elements. + return i >= Cast(this)->input_count() + ? UseKind::kReuse + : UseKind::kUsePermutingElements; case HloOpcode::kFusion: // Uses the memoizing, recursive computation defined above. 
return FusionReusesParamElements::Compute(i, *fused_expression_root()); @@ -2803,7 +2831,7 @@ StatusOr StringToFusionKind( if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } - return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); + return InvalidArgument("Unknown fusion kind: %s", kind_name); } string PaddingConfigToString(const PaddingConfig& padding) { @@ -2812,7 +2840,7 @@ string PaddingConfigToString(const PaddingConfig& padding) { [](const PaddingConfig::PaddingConfigDimension& dim) { return dim.interior_padding() != 0; }); - return Join( + return StrJoin( padding.dimensions(), "x", [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { StrAppend( @@ -2836,11 +2864,15 @@ string OpMetadataToString(const OpMetadata& metadata) { if (metadata.source_line() != 0) { result.push_back(StrCat("source_line=", metadata.source_line())); } - return Join(result, " "); + return StrJoin(result, " "); } string RandomDistributionToString(const RandomDistribution& distribution) { - return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); + return absl::AsciiStrToLower(RandomDistribution_Name(distribution)); +} + +string PrecisionToString(const PrecisionConfigProto::Precision& precision) { + return absl::AsciiStrToLower(PrecisionConfigProto::Precision_Name(precision)); } string ConvolutionDimensionNumbersToString( @@ -2868,8 +2900,8 @@ string ConvolutionDimensionNumbersToString( output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } - return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->", - Join(output_dims, "")); + return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->", + StrJoin(output_dims, "")); } string HloInstruction::DotDimensionNumbersToString() const { @@ -2880,19 +2912,21 @@ string HloInstruction::DotDimensionNumbersToString() const { const DotDimensionNumbers& dnums = *dot_dimension_numbers_; if (!dnums.lhs_batch_dimensions().empty()) { 
result.push_back(StrCat("lhs_batch_dims={", - Join(dnums.lhs_batch_dimensions(), ","), "}")); + StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); } result.push_back(StrCat("lhs_contracting_dims={", - Join(dnums.lhs_contracting_dimensions(), ","), "}")); + StrJoin(dnums.lhs_contracting_dimensions(), ","), + "}")); if (!dnums.rhs_batch_dimensions().empty()) { result.push_back(StrCat("rhs_batch_dims={", - Join(dnums.rhs_batch_dimensions(), ","), "}")); + StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); } result.push_back(StrCat("rhs_contracting_dims={", - Join(dnums.rhs_contracting_dimensions(), ","), "}")); + StrJoin(dnums.rhs_contracting_dimensions(), ","), + "}")); - return Join(result, ", "); + return StrJoin(result, ", "); } StatusOr StringToRandomDistribution(const string& name) { @@ -2906,7 +2940,44 @@ StatusOr StringToRandomDistribution(const string& name) { } return map; }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); + auto found = map->find(absl::AsciiStrToLower(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + +string HloInstruction::PrecisionConfigToString() const { + if (precision_config_.operand_precision().empty()) { + return ""; + } + return StrCat( + "operand_precision={", + StrJoin(precision_config_.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfigProto::Precision_IsValid(precision)) + << precision; + StrAppend(out, PrecisionToString( + static_cast( + precision))); + }), + "}"); +} + +StatusOr StringToPrecision( + const string& name) { + static std::unordered_map* map = [] { + static auto* map = + new std::unordered_map; + for (int i = 0; i < PrecisionConfigProto::Precision_ARRAYSIZE; i++) { + if (PrecisionConfigProto::Precision_IsValid(i)) { + auto value = static_cast(i); + (*map)[PrecisionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(absl::AsciiStrToLower(name)); if (found == 
map->end()) { return InvalidArgument("Unknown distribution"); } @@ -3156,31 +3227,25 @@ const string& HloInstruction::outfeed_config() const { return Cast(this)->outfeed_config(); } -const std::vector& HloInstruction::replica_group_ids() const { - return Cast(this)->replica_group_ids(); +const std::vector& HloInstruction::replica_groups() const { + return Cast(this)->replica_groups(); } -const std::vector& HloInstruction::replica_groups() const { - return Cast(this)->replica_groups(); +const std::vector>& +HloInstruction::source_target_pairs() const { + return Cast(this)->source_target_pairs(); } string HloInstruction::cross_replica_sum_barrier() const { - if (opcode() == HloOpcode::kCrossReplicaSum) { - return Cast(this)->cross_replica_sum_barrier(); - } - return Cast(this)->cross_replica_sum_barrier(); + return Cast(this)->cross_replica_sum_barrier(); } void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - if (opcode() == HloOpcode::kCrossReplicaSum) { - return Cast(this)->set_cross_replica_sum_barrier( - barrier); - } - return Cast(this)->set_cross_replica_sum_barrier( + return Cast(this)->set_cross_replica_sum_barrier( barrier); } -tensorflow::gtl::optional HloInstruction::all_reduce_id() const { +absl::optional HloInstruction::all_reduce_id() const { return Cast(this)->all_reduce_id(); } @@ -3206,6 +3271,18 @@ void HloInstruction::set_convolution_dimension_numbers( } } +int64 HloInstruction::feature_group_count() const { + if (auto convolution = DynCast(this)) { + return convolution->feature_group_count(); + } + return Cast(this)->feature_group_count(); +} + +void HloInstruction::set_feature_group_count(int64 feature_group_count) { + Cast(this)->set_feature_group_count( + feature_group_count); +} + HloComputation* HloInstruction::select() const { return Cast(this)->select(); } @@ -3226,10 +3303,6 @@ const string& HloInstruction::custom_call_target() const { return Cast(this)->custom_call_target(); } -const string& 
HloInstruction::channel_name() const { - return Cast(this)->channel_name(); -} - const PaddingConfig& HloInstruction::padding_config() const { return Cast(this)->padding_config(); } @@ -3246,9 +3319,8 @@ const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const { return Cast(this)->gather_dimension_numbers(); } -tensorflow::gtl::ArraySlice HloInstruction::gather_window_bounds() - const { - return Cast(this)->gather_window_bounds(); +absl::Span HloInstruction::gather_slice_sizes() const { + return Cast(this)->gather_slice_sizes(); } const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 3c575ae6ea8e60f48def4debcd9cfbea63e396b2..cca134e8b45f89a1c395c791029ee68eeec3c8f0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -32,6 +32,11 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -45,10 +50,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -101,6 +103,7 @@ class HloPrintOptions { return HloPrintOptions() .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) .set_print_metadata(false) + .set_print_backend_config(false) .set_compact_operands(true) .set_print_operand_shape(true) .set_print_program_shape(false) @@ -182,7 +185,7 @@ class HloPrintOptions { return print_subcomputation_mode_; } bool print_metadata() const { return print_metadata_; } - bool print_backend_config() const { return print_metadata_; } + bool print_backend_config() const { return print_backend_config_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } @@ -220,7 +223,7 @@ class CanonicalNameMap { return iter->second; } - string new_name = tensorflow::strings::StrCat("tmp_", index++); + string new_name = absl::StrCat("tmp_", index++); canonical_name_map[old_name] = new_name; return new_name; } @@ -347,7 +350,8 @@ class HloInstruction { std::unique_ptr literal); // Creates an Iota instruction. - static std::unique_ptr CreateIota(const Shape& shape); + static std::unique_ptr CreateIota(const Shape& shape, + int64 iota_dimension); // Creates a get tuple element instruction. static std::unique_ptr CreateGetTupleElement( @@ -361,7 +365,7 @@ class HloInstruction { // random numbers from a given distribution. 
static std::unique_ptr CreateRng( const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters); + absl::Span parameters); // Creates a unary instruction (one operand). // Precondition: opcode must be a legitimate unary operation. @@ -388,13 +392,13 @@ class HloInstruction { // Precondition: opcode must be a legitimate variadic operation. static std::unique_ptr CreateVariadic( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Creates a map instruction, where the computation (given by the handle) is // applied element-wise to every element in operands (across the operands, // at a given index) static std::unique_ptr CreateMap( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* map_computation); // Creates a convolution op, where rhs is the convolutional filter @@ -402,12 +406,13 @@ class HloInstruction { static std::unique_ptr CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Creates an FFT op, of the type indicated by fft_type. static std::unique_ptr CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch // dimensions specified in 'dimension_numbers'. @@ -432,9 +437,10 @@ class HloInstruction { // // `reduction_computation`: the reduction function. // - // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all - // replicas belong to one group. Allreduce will be applied within subgroups. - // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // `replica_groups`: each ReplicaGroup contains a list of replica id. 
If + // empty, all replicas belong to one group in the order of 0 - (n-1). + // Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // // `all_reduce_id`: for Allreduce nodes from different modules, if they have @@ -443,11 +449,10 @@ class HloInstruction { // // TODO(b/79737069): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice replica_group_ids, - tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id); + const std::vector& replica_groups, + absl::string_view barrier, const absl::optional& all_reduce_id); // This op handles the communication of an Alltoall operation. On each core, // the operands are N ops in the same shape, where N is the number of cores @@ -462,12 +467,18 @@ class HloInstruction { // within replica 1, 2, 3, and in the gather phase, the received blocks will // be concatenated in the order of 1, 2, 3; another Alltoall will be applied // within replica 4, 5, 0, and the concatenation order is 4, 5, 0. - // - // TODO(b/110096724): This is NOT YET ready to use. static std::unique_ptr CreateAllToAll( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - const std::vector& replica_groups, - tensorflow::StringPiece barrier); + const Shape& shape, absl::Span operands, + const std::vector& replica_groups); + + // Creates a communication instruction that permutes data across replicas. + // Data is sent/received according to the (source_replica_id, + // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a + // target_replica_id in any pair, the output on that replica is a tensor + // consisting of 0(s) in `shape`. 
+ static std::unique_ptr CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -486,24 +497,13 @@ class HloInstruction { static std::unique_ptr CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config); - // Overload which does not require a token. - // TODO(b/80000000): Remove this overload when all uses of infeed are - // converted to take tokens. - static std::unique_ptr CreateInfeed(const Shape& infeed_shape, - const string& config); // Creates an outfeed instruction, which outputs data. outfeed_shape is the // shape of the data being outfed *not* the shape of the outfeed instruction // which is a TOKEN. static std::unique_ptr CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); - // Overload which does not require a token. - // TODO(b/80000000): Remove this overload when all uses of outfeed are - // converted to take tokens. - static std::unique_ptr CreateOutfeed( - const Shape& outfeed_shape, HloInstruction* operand, - tensorflow::StringPiece outfeed_config); + HloInstruction* token_operand, absl::string_view outfeed_config); // Creates an asynchronous send instruction with the given channel id, which // initiates sending the operand data to a unique receive instruction in @@ -536,17 +536,15 @@ class HloInstruction { // start/limit indices. 
static std::unique_ptr CreateSlice( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + absl::Span start_indices, + absl::Span limit_indices, absl::Span strides); // Creates a slice instruction, where the first operand is sliced by // start indices specified in the second operand, and by size specified in // 'slice_sizes'. static std::unique_ptr CreateDynamicSlice( const Shape& shape, HloInstruction* operand, - HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + HloInstruction* start_indices, absl::Span slice_sizes); // Creates a dynamic update slice instruction, which updates a slice // of 'operand' with 'update' and 'start_indices'. @@ -557,7 +555,7 @@ class HloInstruction { // Creates a concatenate instruction, where the operands are concatenated on // the provided dimension. static std::unique_ptr CreateConcatenate( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, int64 dimension); // Creates a reduce instruction, where the computation (given by the handle) @@ -569,7 +567,7 @@ class HloInstruction { // f(f(init, value0), value1), ...) static std::unique_ptr CreateReduce( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation); // A more general, multiple-argument version of the above. @@ -584,9 +582,9 @@ class HloInstruction { // ... // TODO(b/112040122): Add support to this in HLO passes and in backends. 
static std::unique_ptr CreateReduce( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice init_values, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + const Shape& shape, absl::Span operands, + absl::Span init_values, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation); // Creates a reduce-window instruction, where the computation (given @@ -623,7 +621,7 @@ class HloInstruction { // Creates a broadcast instruction. static std::unique_ptr CreateBroadcast( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Creates a sequence of instructions that performs an explicit broadcast of // the operand to the target shape. @@ -653,7 +651,7 @@ class HloInstruction { // Creates a transpose instruction which permutes the operand dimensions. static std::unique_ptr CreateTranspose( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Creates a sort op, with a keys operand, and an optional values operand. static std::unique_ptr CreateSort( @@ -677,9 +675,9 @@ class HloInstruction { static std::unique_ptr CreateGather( const Shape& shape, HloInstruction* operand, - HloInstruction* gather_indices, + HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds); + absl::Span slice_sizes); static std::unique_ptr CreateScatter( const Shape& shape, HloInstruction* operand, @@ -703,43 +701,37 @@ class HloInstruction { static std::unique_ptr CreateFusion( const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, HloComputation* fusion_computation); // Creates a call instruction that applies the given computation on the given // operands. "shape" is the resultant shape. 
static std::unique_ptr CreateCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* computation); // Creates a custom call instruction that applies the given custom call target // to the given operands. "shape" is the resultant shape. static std::unique_ptr CreateCustomCall( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target); - - // Creates a HostCompute instruction, which records host-side control and - // data dependencies for use in instruction scheduling. - static std::unique_ptr CreateHostCompute( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target); // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. static std::unique_ptr CreateTuple( - tensorflow::gtl::ArraySlice elements); + absl::Span elements); // Creates a reverse instruction, which reverses the order of the elements // in the specified dimensions. static std::unique_ptr CreateReverse( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Creates a Afterall instruction used for joining or creating new values of // token type which thread through side-effecting operations. Operands must // all be tokens, and there must be at least one operand. static std::unique_ptr CreateAfterAll( - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Creates an AfterAll instruction which creates a token type out of thin air // (no operands). This is a separate method from CreateAfterAll to facility @@ -776,7 +768,7 @@ class HloInstruction { int64 operand_count() const { return operands_.size(); } // Returns the vector of operands of this instruction. 
- using InstructionVector = tensorflow::gtl::InlinedVector; + using InstructionVector = absl::InlinedVector; const InstructionVector& operands() const { return operands_; } // Returns the vector of unique operands, in the same order they are found @@ -873,6 +865,11 @@ class HloInstruction { return false; } + if (!absl::c_equal(precision_config_.operand_precision(), + other.precision_config_.operand_precision())) { + return false; + } + return IdenticalSlowPath(other, eq_computations); } @@ -1040,7 +1037,7 @@ class HloInstruction { // Returns true if this instruction can be legally fused into a fusion // instruction. - bool IsFusable() const; + bool IsFusible() const; // Returns the sharding applied to this operator. // REQUIRES: has_sharding() is true. @@ -1048,21 +1045,26 @@ class HloInstruction { CHECK(has_sharding()); return *sharding_; } + std::shared_ptr sharding_ptr() const { return sharding_; } + // Returns the sharding applied to this operator, or default_ if none exists. const HloSharding& sharding_or_default(const HloSharding& default_) const { return sharding_ ? *sharding_ : default_; } // Returns the sharding unique device, if any. - tensorflow::gtl::optional sharding_unique_device() const { + absl::optional sharding_unique_device() const { if (sharding_ == nullptr) { - return tensorflow::gtl::optional(); + return absl::optional(); } return sharding_->UniqueDevice(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { - sharding_ = MakeUnique(sharding); + sharding_ = std::make_shared(sharding); + } + void set_sharding(std::shared_ptr sharding) { + sharding_ = std::move(sharding); } void set_single_sharding(const HloSharding& sharding); // Sets a sharding that assigns the current instruction to device. @@ -1098,19 +1100,6 @@ class HloInstruction { // instruction. 
void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // TODO(b/80249101): Remove these methods once HLO scheduling and copy - // insertion are integrated, and we don't need to run a separate pass - // of copy elision anymore. - bool CopyElisionAllowed() const { - CHECK_EQ(HloOpcode::kCopy, opcode_); - return copy_elision_allowed_; - } - - void SetCopyElisionAllowed(bool value) { - CHECK_EQ(HloOpcode::kCopy, opcode_); - copy_elision_allowed_ = value; - } - // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { CHECK(dot_dimension_numbers_ != nullptr); @@ -1120,6 +1109,9 @@ class HloInstruction { // Returns the dump string of the dot dimension numbers. string DotDimensionNumbersToString() const; + // Returns the dump string of the precision configuration. + string PrecisionConfigToString() const; + // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1130,8 +1122,7 @@ class HloInstruction { // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context = nullptr) const; // Returns the computations this instruction directly calls (if any). @@ -1263,6 +1254,20 @@ class HloInstruction { static StatusOr BackendConfigToRawString( const tensorflow::protobuf::Message& proto); + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. 
Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfigProto& precision_config() const { + return precision_config_; + } + void set_precision_config(const PrecisionConfigProto& precision_config) { + precision_config_ = precision_config; + } + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1431,18 +1436,18 @@ class HloInstruction { // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const; - // Delegates to HloAllReduceInstruction::replica_group_ids. - const std::vector& replica_group_ids() const; - - // Delegates to HloAllToAllInstruction::replica_groups. + // Delegates to HloCollectiveInstruction::replica_groups. const std::vector& replica_groups() const; + // Delegates to HloCollectivePermuteInstruction::source_target_pairs. + const std::vector>& source_target_pairs() const; + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. string cross_replica_sum_barrier() const; void set_cross_replica_sum_barrier(const string& barrier); // Delegates to HloAllReduceInstruction::all_reduce_id. - tensorflow::gtl::optional all_reduce_id() const; + absl::optional all_reduce_id() const; // Returns data on the window in a windowed operation such as // convolution. @@ -1466,6 +1471,12 @@ class HloInstruction { void set_convolution_dimension_numbers( const ConvolutionDimensionNumbers& dnums); + // The number of feature groups. Must be a divisor of the input feature + // dimension and output feature dimension. + int64 feature_group_count() const; + + void set_feature_group_count(int64 feature_group_count); + // Delegates to HloSelectAndScatterInstruction::select. 
HloComputation* select() const; @@ -1481,9 +1492,6 @@ class HloInstruction { // Delegates to HloCustomCallInstruction::custom_call_target. const string& custom_call_target() const; - // Delegates to HloHostComputeInstruction::channel_name. - const string& channel_name() const; - // Delegates to HloPadInstruction::padding_config. const PaddingConfig& padding_config() const; @@ -1495,8 +1503,8 @@ class HloInstruction { // Delegates to HloGatherInstruction::gather_dimension_numbers. const GatherDimensionNumbers& gather_dimension_numbers() const; - // Delegates to HloGatherInstruction::gather_window_bounds. - tensorflow::gtl::ArraySlice gather_window_bounds() const; + // Delegates to HloGatherInstruction::gather_slice_sizes. + absl::Span gather_slice_sizes() const; // Delegates to HloScatterInstruction::scatter_dimension_numbers(). const ScatterDimensionNumbers& scatter_dimension_numbers() const; @@ -1522,7 +1530,7 @@ class HloInstruction { // Removes a list of operands with the given indices in ascending order. void RemoveOperandsAtAscendingIndices( - tensorflow::gtl::ArraySlice ascending_indices); + absl::Span ascending_indices); void AppendComputation(HloComputation* computation) { called_computations_.push_back(computation); @@ -1552,8 +1560,7 @@ class HloInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. virtual std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { // TODO(b/80131774): This should be pure virtual. LOG(FATAL) << "Unimplemented method."; @@ -1571,7 +1578,7 @@ class HloInstruction { // NOTE: For all instructions other than kFusion, being elementwise on one of // the operands is equivalent to being elementwise on all the operands. 
virtual bool IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const; + const absl::optional& operand_idx) const; // Prints an instruction to a string. // // The canonical string representation needs to name operands and instruction @@ -1599,7 +1606,7 @@ class HloInstruction { // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( const Shape& shape, HloOpcode opcode, - tensorflow::gtl::ArraySlice operands); + absl::Span operands); // Adds a user for this instruction. void AddUser(HloInstruction* user); @@ -1648,7 +1655,10 @@ class HloInstruction { bool copy_elision_allowed_ = true; // The sharding, if one exists. - std::unique_ptr sharding_; + // Uses std::shared_ptr to allow reuse of the same sharding object between + // HloInstructions and other components as HloSharding can be very large for + // many element tuples. + std::shared_ptr sharding_; // Fields used by the kDomain instruction. std::unique_ptr operand_side_metadata_; @@ -1667,6 +1677,10 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfigProto precision_config_; + // String identifier for instruction. 
string name_; @@ -1689,10 +1703,12 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); +string PrecisionToString(const PrecisionConfigProto::Precision& precision); string ConvolutionDimensionNumbersToString( const ConvolutionDimensionNumbers& dnums); StatusOr StringToRandomDistribution(const string& name); +StatusOr StringToPrecision(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 8a694dde8066ab9a1138b9f7981153d451ddb89e..76b0e940a656ee2f54781b927fdca367a83056c6 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -39,10 +39,8 @@ namespace { using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; -class HloInstructionTest : public HloTestBase { +class HloInstructionTest : public HloVerifiedTestBase { protected: - HloInstructionTest() {} - Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); }; @@ -53,7 +51,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { public: Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("not implemented %s", - HloOpcodeString(hlo_instruction->opcode()).c_str()); + HloOpcodeString(hlo_instruction->opcode())); } Status HandleParameter(HloInstruction* 
parameter) override { @@ -1086,16 +1084,14 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { // Fused expression: - // - // x y - // \ / \ - // min broadcast + // y + // / + // x broadcast + // \ / | + // min | // \ / // sub // - // The fusion instruction is elementwise on `x` because the only path from x - // to sub contains only elementwise operations. It is not elementwise on `y` - // because the path y->broadcast->sub is not all elementwise. const Shape r0f32 = ShapeUtil::MakeShape(F32, {}); const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); @@ -1104,10 +1100,10 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y")); - HloInstruction* min = builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, y)); HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {0})); + builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {})); + HloInstruction* min = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, broadcast)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); @@ -1118,10 +1114,10 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { EXPECT_FALSE(fusion->IsElementwise()); for (int64 operand_idx = 0; operand_idx < fusion->operand_count(); ++operand_idx) { - if (fusion->operand(operand_idx) == x) { - EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx)); - } else { + if (fusion->operand(operand_idx) == y) { EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx)); + } else { + EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx)); } } } @@ -1248,7 +1244,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { auto one = 
builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); + HloInstruction::CreateBroadcast(data_shape, one, {})); auto add = builder.AddInstruction(HloInstruction::CreateBinary( data_shape, HloOpcode::kAdd, dot, add_operand)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( @@ -1355,7 +1351,7 @@ TEST_F(HloInstructionTest, Stringification) { TEST_F(HloInstructionTest, StringifyGather_0) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); - Shape gather_indices_tensor_shape = + Shape start_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); Shape gather_result_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}); @@ -1363,19 +1359,18 @@ TEST_F(HloInstructionTest, StringifyGather_0) { HloComputation::Builder builder("Gather"); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); - HloInstruction* gather_indices = + HloInstruction* start_indices = builder.AddInstruction(HloInstruction::CreateParameter( - 1, gather_indices_tensor_shape, "gather_indices")); - - HloInstruction* gather_instruction = - builder.AddInstruction(HloInstruction::CreateGather( - gather_result_shape, input, gather_indices, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26})); + 1, start_indices_tensor_shape, "start_indices")); + + HloInstruction* gather_instruction = builder.AddInstruction( + HloInstruction::CreateGather(gather_result_shape, input, start_indices, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + 
/*slice_sizes=*/{30, 29, 28, 27, 26})); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1383,15 +1378,15 @@ TEST_F(HloInstructionTest, StringifyGather_0) { EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " - "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " - "gather_dims_to_operand_dims={0,1,2,3,4}, " - "index_vector_dim=4, window_bounds={30,29,28,27,26}"); + "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), " + "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, " + "start_index_map={0,1,2,3,4}, " + "index_vector_dim=4, slice_sizes={30,29,28,27,26}"); } TEST_F(HloInstructionTest, StringifyGather_1) { Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); - Shape gather_indices_tensor_shape = + Shape start_indices_tensor_shape = ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); Shape gather_result_shape = ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); @@ -1399,19 +1394,18 @@ TEST_F(HloInstructionTest, StringifyGather_1) { HloComputation::Builder builder("Gather"); HloInstruction* input = builder.AddInstruction( HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); - HloInstruction* gather_indices = + HloInstruction* start_indices = builder.AddInstruction(HloInstruction::CreateParameter( - 1, gather_indices_tensor_shape, "gather_indices")); - - HloInstruction* gather_instruction = - builder.AddInstruction(HloInstruction::CreateGather( - gather_result_shape, input, gather_indices, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/2), - /*window_bounds=*/{30, 29, 28, 27, 26})); + 1, start_indices_tensor_shape, "start_indices")); + + HloInstruction* gather_instruction = 
builder.AddInstruction( + HloInstruction::CreateGather(gather_result_shape, input, start_indices, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*slice_sizes=*/{30, 29, 28, 27, 26})); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1419,10 +1413,10 @@ TEST_F(HloInstructionTest, StringifyGather_1) { EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), " - "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " - "gather_dims_to_operand_dims={0,1,2,3,4}, " - "index_vector_dim=2, window_bounds={30,29,28,27,26}"); + "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), " + "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, " + "start_index_map={0,1,2,3,4}, " + "index_vector_dim=2, slice_sizes={30,29,28,27,26}"); } TEST_F(HloInstructionTest, StringifyScatter) { @@ -1745,5 +1739,23 @@ TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) { << clone->convolution_dimension_numbers().DebugString(); } +TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) { + constexpr char kHloString[] = R"( + HloModule test_module + ENTRY test { + arg0 = f32[1,2,1] parameter(0) + arg1 = f32[1,1,1] parameter(1) + ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1}, + dim_labels=b0f_0io->b0f, operand_precision={high,default} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kHloString)); + auto* conv = module->entry_computation()->root_instruction(); + + auto clone = conv->Clone(); + EXPECT_THAT(clone->precision_config().operand_precision(), + ::testing::ElementsAre(PrecisionConfigProto::HIGH, + PrecisionConfigProto::DEFAULT)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc 
b/tensorflow/compiler/xla/service/hlo_instructions.cc index 1de5032670ff47cda5599cf736bbd3529cfcaba9..e46afa764f519c9f7b6e3e9a8a37c84bd173b9a2 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -17,6 +17,12 @@ limitations under the License. #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -27,10 +33,10 @@ limitations under the License. namespace xla { namespace { -using ::tensorflow::str_util::CEscape; -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::CEscape; +using absl::StrAppend; +using absl::StrCat; +using absl::StrJoin; bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, const HloInstruction* operand) { @@ -85,11 +91,10 @@ HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction( std::unique_ptr HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], epsilon(), feature_index()); } @@ -107,11 +112,10 @@ HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction( std::unique_ptr HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); - return MakeUnique( + 
return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); } @@ -129,18 +133,17 @@ HloBatchNormGradInstruction::HloBatchNormGradInstruction( std::unique_ptr HloBatchNormGradInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); } -HloFftInstruction::HloFftInstruction( - const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length) +HloFftInstruction::HloFftInstruction(const Shape& shape, + HloInstruction* operand, FftType fft_type, + absl::Span fft_length) : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) { fft_length_.assign(fft_length.begin(), fft_length.end()); AppendOperand(operand); @@ -158,7 +161,7 @@ HloInstructionProto HloFftInstruction::ToProto() const { std::vector HloFftInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {StrCat("fft_type=", FftType_Name(fft_type())), - StrCat("fft_length={", Join(fft_length(), ","), "}")}; + StrCat("fft_length={", StrJoin(fft_length(), ","), "}")}; } bool HloFftInstruction::IdenticalSlowPath( @@ -171,12 +174,11 @@ bool HloFftInstruction::IdenticalSlowPath( } std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], fft_type_, - fft_length_); + return absl::make_unique(shape, new_operands[0], fft_type_, + fft_length_); } HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, @@ 
-226,12 +228,11 @@ HloSendInstruction::HloSendInstruction(HloInstruction* operand, } std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(new_operands[0], new_operands[1], - channel_id(), is_host_transfer()); + return absl::make_unique( + new_operands[0], new_operands[1], channel_id(), is_host_transfer()); } HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, @@ -244,11 +245,10 @@ HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, std::unique_ptr HloSendDoneInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( Cast(new_operands[0]), is_host_transfer()); } @@ -265,11 +265,10 @@ HloRecvInstruction::HloRecvInstruction(const Shape& shape, } std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(), is_host_transfer()); } @@ -287,35 +286,69 @@ HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand, std::unique_ptr HloRecvDoneInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( Cast(new_operands[0]), is_host_transfer()); } -HloAllReduceInstruction::HloAllReduceInstruction( 
- const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice replica_group_ids, - tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id) - : HloInstruction(HloOpcode::kCrossReplicaSum, shape), - replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()), - cross_replica_sum_barrier_(barrier.begin(), barrier.end()), - all_reduce_id_(all_reduce_id) { +HloCollectiveInstruction::HloCollectiveInstruction( + HloOpcode opcode, const Shape& shape, + absl::Span operands, + const std::vector& replica_groups) + : HloInstruction(opcode, shape), replica_groups_(replica_groups) { for (auto operand : operands) { AppendOperand(operand); } - AppendComputation(reduce_computation); } -HloInstructionProto HloAllReduceInstruction::ToProto() const { +HloInstructionProto HloCollectiveInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - for (int64 i : replica_group_ids_) { - proto.add_replica_group_ids(i); + *proto.mutable_replica_groups() = {replica_groups_.begin(), + replica_groups_.end()}; + return proto; +} + +std::vector HloCollectiveInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result; + std::vector replica_group_str; + for (const ReplicaGroup& group : replica_groups()) { + replica_group_str.push_back( + StrCat("{", StrJoin(group.replica_ids(), ","), "}")); } + result.push_back( + StrCat("replica_groups={", StrJoin(replica_group_str, ","), "}")); + return result; +} + +bool HloCollectiveInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + /*eq_computations*/) const { + const auto& casted_other = + static_cast(other); + return absl::c_equal(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return absl::c_equal(a.replica_ids(), b.replica_ids()); + }); +} + 
+HloAllReduceInstruction::HloAllReduceInstruction( + const Shape& shape, absl::Span operands, + HloComputation* reduce_computation, + const std::vector& replica_groups, absl::string_view barrier, + const absl::optional& all_reduce_id) + : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands, + replica_groups), + cross_replica_sum_barrier_(barrier), + all_reduce_id_(all_reduce_id) { + AppendComputation(reduce_computation); +} + +HloInstructionProto HloAllReduceInstruction::ToProto() const { + HloInstructionProto proto = HloCollectiveInstruction::ToProto(); // Proto3 is so sad. if (all_reduce_id_) { proto.set_all_reduce_id(*all_reduce_id_); @@ -325,9 +358,9 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { } std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( - const HloPrintOptions& /*options*/) const { - std::vector result = { - StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")}; + const HloPrintOptions& options) const { + std::vector result = + HloCollectiveInstruction::ExtraAttributesToStringImpl(options); if (!cross_replica_sum_barrier().empty()) { result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); } @@ -342,7 +375,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath( const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return replica_group_ids() == casted_other.replica_group_ids() && + return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && eq_computations(to_apply(), casted_other.to_apply()) && cross_replica_sum_barrier() == casted_other.cross_replica_sum_barrier() && @@ -351,78 +384,80 @@ bool HloAllReduceInstruction::IdenticalSlowPath( std::unique_ptr HloAllReduceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { - return MakeUnique( - shape, new_operands, to_apply(), 
replica_group_ids(), + return absl::make_unique( + shape, new_operands, to_apply(), replica_groups(), cross_replica_sum_barrier(), all_reduce_id()); } HloAllToAllInstruction::HloAllToAllInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - const std::vector& replica_groups, - tensorflow::StringPiece barrier) - : HloInstruction(HloOpcode::kAllToAll, shape), - replica_groups_(replica_groups), - cross_replica_sum_barrier_(barrier.begin(), barrier.end()) { - for (auto operand : operands) { - AppendOperand(operand); - } -} - -bool HloAllToAllInstruction::IdenticalSlowPath( - const HloInstruction& other, - const std::function& - eq_computations) const { - const auto& casted_other = static_cast(other); - return ContainersEqual(replica_groups(), casted_other.replica_groups(), - [](const ReplicaGroup& a, const ReplicaGroup& b) { - return ContainersEqual(a.replica_ids(), - b.replica_ids()); - }) && - cross_replica_sum_barrier() == - casted_other.cross_replica_sum_barrier(); -} + const Shape& shape, absl::Span operands, + const std::vector& replica_groups) + : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands, + replica_groups) {} std::unique_ptr HloAllToAllInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* /*context*/) const { - return MakeUnique( - shape, new_operands, replica_groups(), cross_replica_sum_barrier()); + return absl::make_unique(shape, new_operands, + replica_groups()); } -std::vector HloAllToAllInstruction::ExtraAttributesToStringImpl( - const HloPrintOptions& options) const { - std::vector result; - std::vector replica_group_str; - for (const ReplicaGroup& group : replica_groups()) { - replica_group_str.push_back( - StrCat("{", Join(group.replica_ids(), ","), "}")); - } - result.push_back( - StrCat("replica_groups={", Join(replica_group_str, ","), "}")); 
+HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs) + : HloInstruction(HloOpcode::kCollectivePermute, shape), + source_target_pairs_(source_target_pairs) { + AppendOperand(operand); +} - if (!cross_replica_sum_barrier().empty()) { - result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); +HloInstructionProto HloCollectivePermuteInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (const auto& pair : source_target_pairs()) { + auto* proto_pair = proto.add_source_target_pairs(); + proto_pair->set_source(pair.first); + proto_pair->set_target(pair.second); } + return proto; +} +std::vector +HloCollectivePermuteInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result; + std::vector strs; + for (const auto& pair : source_target_pairs()) { + strs.push_back(StrCat("{", pair.first, ",", pair.second, "}")); + } + result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}")); return result; } -HloInstructionProto HloAllToAllInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_replica_groups() = {replica_groups_.begin(), - replica_groups_.end()}; - proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); - return proto; +bool HloCollectivePermuteInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + /*eq_computations*/) const { + const auto& casted_other = + static_cast(other); + return absl::c_equal(source_target_pairs(), + casted_other.source_target_pairs(), + [](const std::pair& a, + const std::pair& b) { return a == b; }); } -HloReverseInstruction::HloReverseInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) +std::unique_ptr +HloCollectivePermuteInstruction::CloneWithNewOperandsImpl( + const Shape& shape, 
absl::Span new_operands, + HloCloneContext* /*context*/) const { + return absl::make_unique( + shape, new_operands[0], source_target_pairs()); +} + +HloReverseInstruction::HloReverseInstruction(const Shape& shape, + HloInstruction* operand, + absl::Span dimensions) : HloInstruction(HloOpcode::kReverse, shape), dimensions_(dimensions.begin(), dimensions.end()) { AppendOperand(operand); @@ -438,7 +473,7 @@ HloInstructionProto HloReverseInstruction::ToProto() const { std::vector HloReverseInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloReverseInstruction::IdenticalSlowPath( @@ -450,16 +485,15 @@ bool HloReverseInstruction::IdenticalSlowPath( } std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - dimensions()); + return absl::make_unique(shape, new_operands[0], + dimensions()); } HloConcatenateInstruction::HloConcatenateInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, int64 dimension) : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) { for (auto operand : operands) { @@ -477,7 +511,7 @@ HloInstructionProto HloConcatenateInstruction::ToProto() const { std::vector HloConcatenateInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloConcatenateInstruction::IdenticalSlowPath( @@ -491,16 +525,15 @@ bool HloConcatenateInstruction::IdenticalSlowPath( std::unique_ptr 
HloConcatenateInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return MakeUnique(shape, new_operands, - dimensions(0)); + return absl::make_unique(shape, new_operands, + dimensions(0)); } HloReduceInstruction::HloReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice args, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + const Shape& shape, absl::Span args, + absl::Span dimensions_to_reduce, HloComputation* reduce_computation) : HloInstruction(HloOpcode::kReduce, shape), dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { @@ -520,7 +553,7 @@ HloInstructionProto HloReduceInstruction::ToProto() const { std::vector HloReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloReduceInstruction::IdenticalSlowPath( @@ -535,12 +568,11 @@ bool HloReduceInstruction::IdenticalSlowPath( } std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - CHECK_EQ(new_operands.size(), 2); - return MakeUnique(shape, new_operands, dimensions(), - to_apply()); + CHECK_EQ(new_operands.size() % 2, 0); + return absl::make_unique(shape, new_operands, + dimensions(), to_apply()); } HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, @@ -563,7 +595,7 @@ HloInstructionProto HloSortInstruction::ToProto() const { std::vector HloSortInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool 
HloSortInstruction::IdenticalSlowPath( @@ -575,17 +607,17 @@ bool HloSortInstruction::IdenticalSlowPath( } std::unique_ptr HloSortInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { HloInstruction* keys = new_operands[0]; HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr; - return MakeUnique(shape, dimensions(0), keys, values); + return absl::make_unique(shape, dimensions(0), keys, + values); } HloTransposeInstruction::HloTransposeInstruction( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions) + absl::Span dimensions) : HloInstruction(HloOpcode::kTranspose, shape), dimensions_(dimensions.begin(), dimensions.end()) { CHECK_EQ(shape.dimensions().size(), dimensions.size()); @@ -595,7 +627,7 @@ HloTransposeInstruction::HloTransposeInstruction( Permute(dimensions, shape.dimensions()).begin())) << "shape: " << ShapeUtil::HumanString(shape) << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << Join(dimensions, ", ") << "}"; + << ", dimensions: {" << StrJoin(dimensions, ", ") << "}"; AppendOperand(operand); } @@ -616,7 +648,7 @@ HloInstructionProto HloTransposeInstruction::ToProto() const { std::vector HloTransposeInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloTransposeInstruction::IdenticalSlowPath( @@ -629,17 +661,16 @@ bool HloTransposeInstruction::IdenticalSlowPath( std::unique_ptr HloTransposeInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - dimensions()); 
+ return absl::make_unique(shape, new_operands[0], + dimensions()); } HloBroadcastInstruction::HloBroadcastInstruction( const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimension) + absl::Span broadcast_dimension) : HloInstruction(HloOpcode::kBroadcast, shape), dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) { AppendOperand(operand); @@ -655,7 +686,7 @@ HloInstructionProto HloBroadcastInstruction::ToProto() const { std::vector HloBroadcastInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloBroadcastInstruction::IdenticalSlowPath( @@ -668,17 +699,16 @@ bool HloBroadcastInstruction::IdenticalSlowPath( std::unique_ptr HloBroadcastInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - dimensions()); + return absl::make_unique(shape, new_operands[0], + dimensions()); } -HloMapInstruction::HloMapInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation) +HloMapInstruction::HloMapInstruction(const Shape& shape, + absl::Span operands, + HloComputation* map_computation) : HloInstruction(HloOpcode::kMap, shape) { for (auto operand : operands) { AppendOperand(operand); @@ -699,7 +729,7 @@ HloInstructionProto HloMapInstruction::ToProto() const { } bool HloMapInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const { + const absl::optional& operand_idx) const { if (!dimensions().empty()) { // Check that the map is executed in elementwise compatible dimensions. 
if (dimensions().size() != shape().dimensions_size()) { @@ -716,7 +746,7 @@ bool HloMapInstruction::IsElementwiseImpl( std::vector HloMapInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloMapInstruction::IdenticalSlowPath( @@ -727,17 +757,16 @@ bool HloMapInstruction::IdenticalSlowPath( } std::unique_ptr HloMapInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return MakeUnique(shape, new_operands, to_apply()); + return absl::make_unique(shape, new_operands, to_apply()); } -HloSliceInstruction::HloSliceInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) +HloSliceInstruction::HloSliceInstruction(const Shape& shape, + HloInstruction* operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) : HloInstruction(HloOpcode::kSlice, shape), slice_starts_(start_indices.begin(), start_indices.end()), slice_limits_(limit_indices.begin(), limit_indices.end()), @@ -774,7 +803,7 @@ std::vector HloSliceInstruction::ExtraAttributesToStringImpl( bounds.push_back( StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]")); } - return {StrCat("slice={", Join(bounds, ", "), "}")}; + return {StrCat("slice={", StrJoin(bounds, ", "), "}")}; } bool HloSliceInstruction::IdenticalSlowPath( @@ -788,12 +817,11 @@ bool HloSliceInstruction::IdenticalSlowPath( } std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return 
MakeUnique(shape, new_operands[0], slice_starts_, - slice_limits_, slice_strides_); + return absl::make_unique( + shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_); } HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) @@ -812,7 +840,7 @@ HloInstructionProto HloConstantInstruction::ToProto() const { } bool HloConstantInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const { + const absl::optional& operand_idx) const { return true; } @@ -842,10 +870,9 @@ bool HloConstantInstruction::IdenticalSlowPath( std::unique_ptr HloConstantInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return MakeUnique(literal_->CloneToUnique()); + return absl::make_unique(literal_->CloneToUnique()); } string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( @@ -860,7 +887,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); - std::vector v = tensorflow::str_util::Split(tmp, ' '); + std::vector v = absl::StrSplit(tmp, ' '); bool first = true; // Concatenate elements in "v" with spaces separating them, but ignoring // empty entries. 
@@ -900,8 +927,7 @@ bool HloTraceInstruction::IdenticalSlowPath( } std::unique_ptr HloTraceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode()); } @@ -919,7 +945,7 @@ HloFusionInstruction::HloFusionInstruction(const Shape& shape, HloFusionInstruction::HloFusionInstruction( const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, + absl::Span operands, HloComputation* fusion_computation) : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { for (auto operand : operands) { @@ -952,7 +978,7 @@ HloInstructionProto HloFusionInstruction::ToProto() const { } bool HloFusionInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const { + const absl::optional& operand_idx) const { if (!operand_idx.has_value()) { for (auto* fused : fused_instructions()) { if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) { @@ -1155,7 +1181,7 @@ HloInstruction* HloFusionInstruction::FuseInstructionInternal( HloInstruction* HloFusionInstruction::CloneAndFuseInternal( HloInstruction* instruction_to_fuse, bool add_output) { - CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); + CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString(); VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); HloInstruction* clone = nullptr; if (called_computations().empty()) { @@ -1326,8 +1352,7 @@ bool HloFusionInstruction::IdenticalSlowPath( } std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { HloModule* module = context != nullptr ? 
context->module() : GetModule(); HloComputation* new_fused_computation = nullptr; @@ -1339,8 +1364,8 @@ std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( new_fused_computation = module->AddEmbeddedComputation( fused_instructions_computation()->Clone("clone", context)); } - return MakeUnique(shape, fusion_kind(), new_operands, - new_fused_computation); + return absl::make_unique( + shape, fusion_kind(), new_operands, new_fused_computation); } Status HloFusionInstruction::DeduplicateFusionOperands() { @@ -1365,7 +1390,7 @@ Status HloFusionInstruction::DeduplicateFusionOperands() { HloRngInstruction::HloRngInstruction( const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters) + absl::Span parameters) : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) { for (HloInstruction* param : parameters) { AppendOperand(param); @@ -1384,7 +1409,7 @@ std::vector HloRngInstruction::ExtraAttributesToStringImpl( } bool HloRngInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const { + const absl::optional& operand_idx) const { return true; } @@ -1396,10 +1421,10 @@ bool HloRngInstruction::IdenticalSlowPath( } std::unique_ptr HloRngInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return MakeUnique(shape, distribution_, new_operands); + return absl::make_unique(shape, distribution_, + new_operands); } HloParameterInstruction::HloParameterInstruction(int64 parameter_number, @@ -1432,10 +1457,10 @@ bool HloParameterInstruction::IdenticalSlowPath( std::unique_ptr HloParameterInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return MakeUnique(parameter_number_, shape, name()); + return 
absl::make_unique(parameter_number_, shape, + name()); } HloGetTupleElementInstruction::HloGetTupleElementInstruction( @@ -1467,12 +1492,11 @@ bool HloGetTupleElementInstruction::IdenticalSlowPath( std::unique_ptr HloGetTupleElementInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique(shape, new_operands[0], - tuple_index()); + return absl::make_unique( + shape, new_operands[0], tuple_index()); } HloReducePrecisionInstruction::HloReducePrecisionInstruction( @@ -1510,11 +1534,10 @@ bool HloReducePrecisionInstruction::IdenticalSlowPath( std::unique_ptr HloReducePrecisionInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], exponent_bits(), mantissa_bits()); } @@ -1528,13 +1551,6 @@ HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, AppendOperand(token_operand); } -HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape, - const string& config) - : HloInstruction(HloOpcode::kInfeed, - ShapeUtil::MakeTupleShape( - {infeed_shape, ShapeUtil::MakeTokenShape()})), - infeed_config_(config) {} - HloInstructionProto HloInfeedInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_infeed_config(infeed_config_); @@ -1558,24 +1574,20 @@ bool HloInfeedInstruction::IdenticalSlowPath( } std::unique_ptr HloInfeedInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - if (new_operands.empty()) { - return MakeUnique(infeed_shape(), infeed_config()); - } else { - 
CHECK_EQ(new_operands.size(), 1); - return MakeUnique(infeed_shape(), new_operands[0], - infeed_config()); - } + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique( + infeed_shape(), new_operands[0], infeed_config()); } -HloOutfeedInstruction::HloOutfeedInstruction( - const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) +HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + HloInstruction* token_operand, + absl::string_view outfeed_config) : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), outfeed_shape_(outfeed_shape), - outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { + outfeed_config_(outfeed_config) { CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) << "Outfeed shape " << outfeed_shape << " must be compatible with operand shape " << operand->shape(); @@ -1583,18 +1595,6 @@ HloOutfeedInstruction::HloOutfeedInstruction( AppendOperand(token_operand); } -HloOutfeedInstruction::HloOutfeedInstruction( - const Shape& outfeed_shape, HloInstruction* operand, - tensorflow::StringPiece outfeed_config) - : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), - outfeed_shape_(outfeed_shape), - outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { - CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) - << "Outfeed shape " << outfeed_shape - << " must be compatible with operand shape " << operand->shape(); - AppendOperand(operand); -} - HloInstructionProto HloOutfeedInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_outfeed_config(outfeed_config()); @@ -1619,25 +1619,21 @@ bool HloOutfeedInstruction::IdenticalSlowPath( } std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const 
{ - if (new_operands.size() == 1) { - return MakeUnique(outfeed_shape(), new_operands[0], - outfeed_config()); - } else { - CHECK_EQ(new_operands.size(), 2); - return MakeUnique(outfeed_shape(), new_operands[0], - new_operands[1], outfeed_config()); - } + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique( + outfeed_shape(), new_operands[0], new_operands[1], outfeed_config()); } HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) : HloInstruction(HloOpcode::kConvolution, shape), window_(window), - convolution_dimension_numbers_(dimension_numbers) { + convolution_dimension_numbers_(dimension_numbers), + feature_group_count_(feature_group_count) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1664,6 +1660,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_window() = window_; *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; + proto.set_feature_group_count(feature_group_count_); return proto; } @@ -1675,6 +1672,7 @@ std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( } extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( convolution_dimension_numbers_))); + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); return extra; } @@ -1684,6 +1682,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath( eq_computations) const { const auto& casted_other = static_cast(other); + if (feature_group_count_ != other.feature_group_count()) { + return false; + } return protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), @@ -1692,13 +1693,12 @@ bool 
HloConvolutionInstruction::IdenticalSlowPath( std::unique_ptr HloConvolutionInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(shape, new_operands[0], - new_operands[1], window(), - convolution_dimension_numbers_); + return absl::make_unique( + shape, new_operands[0], new_operands[1], window(), + convolution_dimension_numbers_, feature_group_count_); } HloReduceWindowInstruction::HloReduceWindowInstruction( @@ -1737,11 +1737,10 @@ bool HloReduceWindowInstruction::IdenticalSlowPath( std::unique_ptr HloReduceWindowInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], window(), to_apply()); } @@ -1786,21 +1785,20 @@ bool HloSelectAndScatterInstruction::IdenticalSlowPath( std::unique_ptr HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], select(), window(), new_operands[1], new_operands[2], scatter()); } HloCustomCallInstruction::HloCustomCallInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target) : HloInstruction(HloOpcode::kCustomCall, shape), - custom_call_target_(custom_call_target.begin(), - custom_call_target.end()) { + custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + feature_group_count_(1) { for 
(auto operand : operands) { AppendOperand(operand); } @@ -1816,6 +1814,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { *convolution_dimension_numbers_; } proto.set_custom_call_target(custom_call_target_); + proto.set_feature_group_count(feature_group_count_); return proto; } @@ -1830,6 +1829,9 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( "dim_labels=", ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); } + if (feature_group_count_ != 1) { + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + } // By contract, we print the custom call target even if // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. @@ -1857,60 +1859,28 @@ bool HloCustomCallInstruction::IdenticalSlowPath( casted_other.convolution_dimension_numbers()))) { return false; } + if (feature_group_count_ != casted_other.feature_group_count_) { + return false; + } return custom_call_target_ == casted_other.custom_call_target_; } std::unique_ptr HloCustomCallInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - auto cloned = MakeUnique(shape, new_operands, - custom_call_target()); + auto cloned = absl::make_unique( + shape, new_operands, custom_call_target()); if (window_ != nullptr) { cloned->set_window(*window_); } if (convolution_dimension_numbers_ != nullptr) { cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_); } + cloned->set_feature_group_count(feature_group_count_); return std::move(cloned); } -HloHostComputeInstruction::HloHostComputeInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) - : HloInstruction(HloOpcode::kHostCompute, shape), - channel_name_(channel_name.begin(), channel_name.end()), - 
cost_estimate_ns_(cost_estimate_ns) { - for (auto operand : operands) { - AppendOperand(operand); - } -} - -HloInstructionProto HloHostComputeInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); - proto.set_channel_name(channel_name_); - proto.set_cost_estimate_ns(cost_estimate_ns_); - return proto; -} - -bool HloHostComputeInstruction::IdenticalSlowPath( - const HloInstruction& other, - const std::function& - eq_computations) const { - // Not yet supported. - return false; -} - -std::unique_ptr -HloHostComputeInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, - HloCloneContext* context) const { - return MakeUnique( - shape, new_operands, channel_name_, cost_estimate_ns_); -} - HloPadInstruction::HloPadInstruction(const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, @@ -1941,17 +1911,16 @@ bool HloPadInstruction::IdenticalSlowPath( } std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique(shape, new_operands[0], new_operands[1], - padding_config_); + return absl::make_unique(shape, new_operands[0], + new_operands[1], padding_config_); } HloDynamicSliceInstruction::HloDynamicSliceInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes) + absl::Span slice_sizes) : HloInstruction(HloOpcode::kDynamicSlice, shape), dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { AppendOperand(operand); @@ -1968,8 +1937,8 @@ HloInstructionProto HloDynamicSliceInstruction::ToProto() const { std::vector HloDynamicSliceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return { - StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")}; + 
return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","), + "}")}; } bool HloDynamicSliceInstruction::IdenticalSlowPath( @@ -1981,60 +1950,57 @@ bool HloDynamicSliceInstruction::IdenticalSlowPath( std::unique_ptr HloDynamicSliceInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); } HloGatherInstruction::HloGatherInstruction( - const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds) + absl::Span slice_sizes) : HloInstruction(HloOpcode::kGather, shape) { AppendOperand(operand); - AppendOperand(gather_indices); + AppendOperand(start_indices); gather_dimension_numbers_ = - MakeUnique(gather_dim_numbers); - c_copy(window_bounds, std::back_inserter(gather_window_bounds_)); + absl::make_unique(gather_dim_numbers); + absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_)); } string HloGatherInstruction::GatherDimensionNumbersToString() const { CHECK(gather_dimension_numbers_ != nullptr); - string output_window_dims = - StrCat("output_window_dims={", - Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); - string elided_window_dims = - StrCat("elided_window_dims={", - Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); - string gather_dims_to_operand_dims = StrCat( - "gather_dims_to_operand_dims={", - Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string offset_dims = + StrCat("offset_dims={", + StrJoin(gather_dimension_numbers_->offset_dims(), ","), "}"); + string collapsed_slice_dims = StrCat( + "collapsed_slice_dims={", 
+ StrJoin(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}"); + string start_index_map = + StrCat("start_index_map={", + StrJoin(gather_dimension_numbers_->start_index_map(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); - return Join>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, - index_vector_dim}, + return StrJoin>( + {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim}, ", "); } /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice output_window_dims, - tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, - int64 index_vector_dim) { + absl::Span offset_dims, + absl::Span collapsed_slice_dims, + absl::Span start_index_map, int64 index_vector_dim) { GatherDimensionNumbers gather_dim_numbers; - for (int64 output_window_dim : output_window_dims) { - gather_dim_numbers.add_output_window_dims(output_window_dim); + for (int64 output_window_dim : offset_dims) { + gather_dim_numbers.add_offset_dims(output_window_dim); } - for (int64 elided_window_dim : elided_window_dims) { - gather_dim_numbers.add_elided_window_dims(elided_window_dim); + for (int64 elided_window_dim : collapsed_slice_dims) { + gather_dim_numbers.add_collapsed_slice_dims(elided_window_dim); } - for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { - gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + for (int64 gather_dim_to_input_dim : start_index_map) { + gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim); } gather_dim_numbers.set_index_vector_dim(index_vector_dim); @@ -2044,8 +2010,8 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const { HloInstructionProto HloGatherInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_gather_dimension_numbers() = 
gather_dimension_numbers(); - for (int64 bound : gather_window_bounds()) { - proto.add_gather_window_bounds(bound); + for (int64 bound : gather_slice_sizes()) { + proto.add_gather_slice_sizes(bound); } return proto; } @@ -2053,7 +2019,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const { std::vector HloGatherInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {GatherDimensionNumbersToString(), - StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")}; + StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")}; } bool HloGatherInstruction::IdenticalSlowPath( @@ -2064,17 +2030,16 @@ bool HloGatherInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals( gather_dimension_numbers(), casted_other.gather_dimension_numbers()) && - gather_window_bounds() == casted_other.gather_window_bounds(); + gather_slice_sizes() == casted_other.gather_slice_sizes(); } std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], gather_dimension_numbers(), - gather_window_bounds()); + gather_slice_sizes()); } HloScatterInstruction::HloScatterInstruction( @@ -2088,24 +2053,24 @@ HloScatterInstruction::HloScatterInstruction( AppendOperand(updates); AppendComputation(update_computation); scatter_dimension_numbers_ = - MakeUnique(scatter_dim_numbers); + absl::make_unique(scatter_dim_numbers); } string HloScatterInstruction::ScatterDimensionNumbersToString() const { - string update_window_dims = - StrCat("update_window_dims={", - Join(scatter_dimension_numbers().update_window_dims(), ","), "}"); + string update_window_dims = StrCat( + "update_window_dims={", + StrJoin(scatter_dimension_numbers().update_window_dims(), ","), "}"); string 
inserted_window_dims = StrCat( "inserted_window_dims={", - Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + StrJoin(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); string scatter_dims_to_operand_dims = StrCat( "scatter_dims_to_operand_dims={", - Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + StrJoin(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); - return Join>( + return StrJoin>( {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, index_vector_dim}, ", "); @@ -2113,9 +2078,9 @@ string HloScatterInstruction::ScatterDimensionNumbersToString() const { /* static */ ScatterDimensionNumbers HloScatterInstruction::MakeScatterDimNumbers( - tensorflow::gtl::ArraySlice update_window_dims, - tensorflow::gtl::ArraySlice inserted_window_dims, - tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + absl::Span update_window_dims, + absl::Span inserted_window_dims, + absl::Span scatter_dims_to_operand_dims, int64 index_vector_dim) { ScatterDimensionNumbers scatter_dim_numbers; for (int64 update_window_dim : update_window_dims) { @@ -2155,13 +2120,41 @@ bool HloScatterInstruction::IdenticalSlowPath( } std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique( + return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), scatter_dimension_numbers()); } +HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension) + : HloInstruction(HloOpcode::kIota, shape), + iota_dimension_(iota_dimension) {} + +HloInstructionProto HloIotaInstruction::ToProto() const { + HloInstructionProto proto = 
HloInstruction::ToProto(); + proto.add_dimensions(iota_dimension()); + return proto; +} + +std::vector HloIotaInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("iota_dimension=", iota_dimension())}; +} + +bool HloIotaInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return iota_dimension() == casted_other.iota_dimension(); +} + +std::unique_ptr HloIotaInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + return absl::make_unique(shape, iota_dimension()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 9586ad667345111d05015e035c93fe6578e3b665..323038357993c4e9b99d1527aa8f593ada92f1c8 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -66,8 +67,7 @@ class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -81,8 +81,7 @@ class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -96,8 +95,7 @@ class HloBatchNormGradInstruction : public HloBatchNormInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -105,7 +103,7 @@ class HloFftInstruction : public HloInstruction { public: explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + absl::Span fft_length); FftType fft_type() const { return fft_type_; } const std::vector& fft_length() const { return fft_length_; } @@ -123,8 +121,7 @@ class HloFftInstruction : public HloInstruction { // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Describes FFT type for an FFT instruction. @@ -173,8 +170,7 @@ class HloSendInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -186,8 +182,7 @@ class HloSendDoneInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -199,8 +194,7 @@ class HloRecvInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; @@ -212,24 +206,41 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction { private: // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; }; -class HloAllReduceInstruction : public HloInstruction { +class HloCollectiveInstruction : public HloInstruction { + public: + const std::vector& replica_groups() const { + return replica_groups_; + } + + protected: + explicit HloCollectiveInstruction( + HloOpcode opcode, const Shape& shape, + absl::Span operands, + const std::vector& replica_groups); + + HloInstructionProto ToProto() const override; + + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + std::vector replica_groups_; +}; + +class HloAllReduceInstruction : public HloCollectiveInstruction { public: explicit HloAllReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, absl::Span operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice replica_group_ids, - tensorflow::StringPiece barrier, - const tensorflow::gtl::optional& all_reduce_id); - - // Returns the group ids of each replica for 
CrossReplicaSum op. - const std::vector& replica_group_ids() const { - return replica_group_ids_; - } + const std::vector& replica_groups, + absl::string_view barrier, const absl::optional& all_reduce_id); // Returns the barrier config used for the CrossReplicaSum implementation of // each backend. @@ -240,9 +251,7 @@ class HloAllReduceInstruction : public HloInstruction { cross_replica_sum_barrier_ = barrier; } - tensorflow::gtl::optional all_reduce_id() const { - return all_reduce_id_; - } + absl::optional all_reduce_id() const { return all_reduce_id_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -257,41 +266,42 @@ class HloAllReduceInstruction : public HloInstruction { // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // The group id of each replica for CrossReplicaSum. - std::vector replica_group_ids_; - // The string representation of the barrier config used for CrossReplicaSum. string cross_replica_sum_barrier_; // For Allreduce nodes from different modules, if they have the same // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be // applied cross modules. - tensorflow::gtl::optional all_reduce_id_; + absl::optional all_reduce_id_; }; -class HloAllToAllInstruction : public HloInstruction { +class HloAllToAllInstruction : public HloCollectiveInstruction { public: explicit HloAllToAllInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operand, - const std::vector& replica_groups, - tensorflow::StringPiece barrier); + const Shape& shape, absl::Span operands, + const std::vector& replica_groups); - const std::vector& replica_groups() const { - return replica_groups_; - } + private: + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; +}; - // TODO(b/110096724): rename this. - void set_cross_replica_sum_barrier(string barrier) { - cross_replica_sum_barrier_ = barrier; - } - string cross_replica_sum_barrier() const { - return cross_replica_sum_barrier_; +class HloCollectivePermuteInstruction : public HloInstruction { + public: + explicit HloCollectivePermuteInstruction( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs); + + const std::vector>& source_target_pairs() const { + return source_target_pairs_; } + // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; private: @@ -304,20 +314,16 @@ class HloAllToAllInstruction : public HloInstruction { // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - std::vector replica_groups_; - - // The string representation of the barrier config. - string cross_replica_sum_barrier_; + const std::vector> source_target_pairs_; }; class HloReverseInstruction : public HloInstruction { public: explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -333,8 +339,7 @@ class HloReverseInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -342,9 +347,9 @@ class HloReverseInstruction : public HloInstruction { class HloConcatenateInstruction : public HloInstruction { public: - explicit HloConcatenateInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - int64 dimension); + explicit HloConcatenateInstruction(const Shape& shape, + absl::Span operands, + int64 dimension); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -362,8 +367,7 @@ class HloConcatenateInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -371,26 +375,28 @@ class HloConcatenateInstruction : public HloInstruction { class HloReduceInstruction : public HloInstruction { public: - explicit HloReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice args, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reduce_computation); + explicit HloReduceInstruction(const Shape& shape, + absl::Span args, + absl::Span dimensions_to_reduce, + HloComputation* reduce_computation); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } // Returns a serialized representation of this instruction. 
HloInstructionProto ToProto() const override; + // Returns the number of input arrays (and, consequentially, the number of + // init values) this reduce has. + int64 input_count() const { return operand_count() / 2; } + // Returns the input tensors to be reduced. - tensorflow::gtl::ArraySlice inputs() const { - return tensorflow::gtl::ArraySlice(operands(), 0, - operand_count() / 2); + absl::Span inputs() const { + return absl::MakeSpan(operands()).subspan(0, input_count()); } // Returns the init values of the reduction. - tensorflow::gtl::ArraySlice init_values() const { - return tensorflow::gtl::ArraySlice( - operands(), operand_count() / 2, operand_count()); + absl::Span init_values() const { + return absl::MakeSpan(operands()).subspan(input_count(), operand_count()); } private: @@ -402,8 +408,7 @@ class HloReduceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -431,8 +436,7 @@ class HloSortInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -440,9 +444,8 @@ class HloSortInstruction : public HloInstruction { class HloTransposeInstruction : public HloInstruction { public: - explicit HloTransposeInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice dimensions); + explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand, + absl::Span dimensions); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -460,8 +463,7 @@ class HloTransposeInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -469,9 +471,8 @@ class HloTransposeInstruction : public HloInstruction { class HloBroadcastInstruction : public HloInstruction { public: - explicit HloBroadcastInstruction( - const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice broadcast_dimension); + explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand, + absl::Span broadcast_dimension); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -487,8 +488,7 @@ class HloBroadcastInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -496,9 +496,9 @@ class HloBroadcastInstruction : public HloInstruction { class HloMapInstruction : public HloInstruction { public: - explicit HloMapInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* map_computation); + explicit HloMapInstruction(const Shape& shape, + absl::Span operands, + HloComputation* map_computation); // Returns the dimension sizes or numbers associated with this instruction. const std::vector& dimensions() const override { return dimensions_; } int64 dimensions(int64 index) const override { return dimensions()[index]; } @@ -507,7 +507,7 @@ class HloMapInstruction : public HloInstruction { private: bool IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const override; + const absl::optional& operand_idx) const override; std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; bool IdenticalSlowPath( @@ -516,8 +516,7 @@ class HloMapInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::vector dimensions_; @@ -526,9 +525,9 @@ class HloMapInstruction : public HloInstruction { class HloSliceInstruction : public HloInstruction { public: explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); HloInstructionProto ToProto() const override; @@ -567,8 +566,7 @@ class HloSliceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Describes the [begin, end) index range for a slice. @@ -600,7 +598,7 @@ class HloConstantInstruction : public HloInstruction { private: bool IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const override; + const absl::optional& operand_idx) const override; bool IdenticalSlowPath( const HloInstruction& other, const std::function& @@ -610,8 +608,7 @@ class HloConstantInstruction : public HloInstruction { CanonicalNameMap* canonical_name_map) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // TODO(b/36360764): Remove unique_ptr wrapping. 
std::unique_ptr literal_; @@ -632,8 +629,7 @@ class HloTraceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // TODO(b/36360764): Remove unique_ptr wrapping. std::unique_ptr literal_; @@ -644,10 +640,9 @@ class HloFusionInstruction : public HloInstruction { explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root); - explicit HloFusionInstruction( - const Shape& shape, FusionKind fusion_kind, - tensorflow::gtl::ArraySlice operands, - HloComputation* fusion_computation); + explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, + absl::Span operands, + HloComputation* fusion_computation); string ToCategory() const override; // Returns a serialized representation of this instruction. @@ -751,7 +746,7 @@ class HloFusionInstruction : public HloInstruction { bool add_output = false); bool IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const override; + const absl::optional& operand_idx) const override; std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; bool IdenticalSlowPath( @@ -760,8 +755,7 @@ class HloFusionInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The type of the fusion. Used by kFusion only. 
@@ -770,9 +764,9 @@ class HloFusionInstruction : public HloInstruction { class HloRngInstruction : public HloInstruction { public: - explicit HloRngInstruction( - const Shape& shape, RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters); + explicit HloRngInstruction(const Shape& shape, + RandomDistribution distribution, + absl::Span parameters); // Returns the random distribution for this rng node. RandomDistribution random_distribution() const { return distribution_; } // Returns a serialized representation of this instruction. @@ -780,7 +774,7 @@ class HloRngInstruction : public HloInstruction { private: bool IsElementwiseImpl( - const tensorflow::gtl::optional& operand_idx) const override; + const absl::optional& operand_idx) const override; std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; bool IdenticalSlowPath( @@ -789,8 +783,7 @@ class HloRngInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The distribution requested for random number generation. @@ -815,8 +808,7 @@ class HloParameterInstruction : public HloInstruction { CanonicalNameMap* canonical_name_map) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; int64 parameter_number_ = 0; @@ -840,8 +832,7 @@ class HloGetTupleElementInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; int64 tuple_index_ = -1; @@ -869,8 +860,7 @@ class HloReducePrecisionInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The bit sizes for a reduce-precision operation. @@ -883,10 +873,6 @@ class HloInfeedInstruction : public HloInstruction { explicit HloInfeedInstruction(const Shape& infeed_shape, HloInstruction* token_operand, const string& config); - // TODO(b/80000000): Remove this constructor when all uses of infeed are - // converted to take tokens. - explicit HloInfeedInstruction(const Shape& infeed_shape, - const string& config); // Returns the infeed configuration string. The infeed configuration includes // any metadata needed for the backend compiler (e.g., infeed buffer address) // and is target-dependent. @@ -911,8 +897,7 @@ class HloInfeedInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The string representation of the infeed configuration. @@ -924,13 +909,7 @@ class HloOutfeedInstruction : public HloInstruction { explicit HloOutfeedInstruction(const Shape& outfeed_shape, HloInstruction* operand, HloInstruction* token_operand, - tensorflow::StringPiece outfeed_config); - // TODO(b/80000000): Remove this constructor when all uses of outfeed are - // converted to take tokens. 
- explicit HloOutfeedInstruction(const Shape& outfeed_shape, - HloInstruction* operand, - tensorflow::StringPiece outfeed_config); - + absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); @@ -950,8 +929,7 @@ class HloOutfeedInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Shape of outfeed request. @@ -965,7 +943,8 @@ class HloConvolutionInstruction : public HloInstruction { explicit HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); const Window& window() const override { return window_; } void set_window(const Window& window) override { window_ = window; } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -975,6 +954,9 @@ class HloConvolutionInstruction : public HloInstruction { const ConvolutionDimensionNumbers& dnums) { convolution_dimension_numbers_ = dnums; } + // The number of feature groups. Must be a divisor of the input feature + // dimension and output feature dimension. + int64 feature_group_count() const { return feature_group_count_; } string ToCategory() const override; // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -988,12 +970,14 @@ class HloConvolutionInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; Window window_; // Describes the dimension numbers used for a convolution. ConvolutionDimensionNumbers convolution_dimension_numbers_; + // The number of feature groups. Must be a divisor of the input feature + // dimension and output feature dimension. + int64 feature_group_count_; }; class HloReduceWindowInstruction : public HloInstruction { @@ -1017,8 +1001,7 @@ class HloReduceWindowInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; Window window_; }; @@ -1066,24 +1049,23 @@ class HloSelectAndScatterInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; Window window_; }; class HloCustomCallInstruction : public HloInstruction { public: - explicit HloCustomCallInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target); + explicit HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target); const Window& window() const override { CHECK(window_ != nullptr); return *window_; } void set_window(const Window& window) override { - window_ = MakeUnique(window); + window_ = absl::make_unique(window); } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -1094,9 +1076,13 @@ class HloCustomCallInstruction : public HloInstruction { void set_convolution_dimension_numbers( const ConvolutionDimensionNumbers& dnums) { convolution_dimension_numbers_ = - MakeUnique(dnums); + absl::make_unique(dnums); } const string& custom_call_target() const { return custom_call_target_; } + void set_feature_group_count(int64 feature_group_count) { + feature_group_count_ = feature_group_count; + } + int64 feature_group_count() const { return feature_group_count_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1109,8 +1095,7 @@ class HloCustomCallInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Name of a global symbol to call, only present for kCustomCall. 
string custom_call_target_; @@ -1118,33 +1103,8 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr window_; // Describes the dimension numbers used for a convolution. std::unique_ptr convolution_dimension_numbers_; -}; - -class HloHostComputeInstruction : public HloInstruction { - public: - explicit HloHostComputeInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); - // Returns the channel name associated with the instruction. The name is - // used to identify host Send/Recv operations. - const string& channel_name() const { return channel_name_; } - // Returns a serialized representation of this instruction. - HloInstructionProto ToProto() const override; - - private: - bool IdenticalSlowPath( - const HloInstruction& other, - const std::function& - eq_computations) const override; - // Implementation for non-common logic of CloneWithNewOperands. - std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, - HloCloneContext* context) const override; - // Name to use for host send/recv channels. - string channel_name_; - // Estimate of the duration of a host computation in nanoseconds. - int64 cost_estimate_ns_ = 0; + // The number of feature groups. This is used for grouped convolutions. + int64 feature_group_count_; }; class HloPadInstruction : public HloInstruction { @@ -1166,8 +1126,7 @@ class HloPadInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // The padding configuration that describes the edge padding and interior @@ -1177,10 +1136,10 @@ class HloPadInstruction : public HloInstruction { class HloDynamicSliceInstruction : public HloInstruction { public: - explicit HloDynamicSliceInstruction( - const Shape& shape, HloInstruction* operand, - HloInstruction* start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + explicit HloDynamicSliceInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* start_indices, + absl::Span slice_sizes); // Old methods kept for smooth subclassing transition END. // Returns the size of the slice in the given dimension for a dynamic // slice node. @@ -1202,8 +1161,7 @@ class HloDynamicSliceInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; // Describes the [start, start + size) range size for a dynamic slice @@ -1215,15 +1173,15 @@ class HloGatherInstruction : public HloInstruction { public: explicit HloGatherInstruction( const Shape& shape, HloInstruction* operand, - HloInstruction* gather_indices, + HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds); + absl::Span slice_sizes); const GatherDimensionNumbers& gather_dimension_numbers() const { CHECK(gather_dimension_numbers_ != nullptr); return *gather_dimension_numbers_; } - tensorflow::gtl::ArraySlice gather_window_bounds() const { - return gather_window_bounds_; + absl::Span gather_slice_sizes() const { + return gather_slice_sizes_; } // Returns the dump string of the gather dimension numbers. string GatherDimensionNumbersToString() const; @@ -1232,10 +1190,9 @@ class HloGatherInstruction : public HloInstruction { // Creates an instance of GatherDimensionNumbers. 
static GatherDimensionNumbers MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice output_window_dims, - tensorflow::gtl::ArraySlice elided_window_dims, - tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, - int64 index_vector_dim); + absl::Span offset_dims, + absl::Span collapsed_slice_dims, + absl::Span start_index_map, int64 index_vector_dim); private: std::vector ExtraAttributesToStringImpl( @@ -1245,12 +1202,11 @@ class HloGatherInstruction : public HloInstruction { const std::function& eq_computations) const override; std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::unique_ptr gather_dimension_numbers_; - std::vector gather_window_bounds_; + std::vector gather_slice_sizes_; }; class HloScatterInstruction : public HloInstruction { @@ -1271,9 +1227,9 @@ class HloScatterInstruction : public HloInstruction { // Creates an instance of ScatterDimensionNumbers. static ScatterDimensionNumbers MakeScatterDimNumbers( - tensorflow::gtl::ArraySlice update_window_dims, - tensorflow::gtl::ArraySlice inserted_window_dims, - tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + absl::Span update_window_dims, + absl::Span inserted_window_dims, + absl::Span scatter_dims_to_operand_dims, int64 index_vector_dim); private: @@ -1285,13 +1241,35 @@ class HloScatterInstruction : public HloInstruction { eq_computations) const override; // Implementation for non-common logic of CloneWithNewOperands. 
std::unique_ptr CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice new_operands, + const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; std::unique_ptr scatter_dimension_numbers_; }; +class HloIotaInstruction : public HloInstruction { + public: + explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension); + // Returns the dimension sizes or numbers associated with this instruction. + int64 iota_dimension() const { return iota_dimension_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + const int64 iota_dimension_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 8e0d38b6a63917582b8bfa10f205e1ed511efef3..d9be841dd751651ba029998fd062fcaec3691945 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -17,20 +17,20 @@ limitations under the License. 
#include +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" namespace xla { - -using ::tensorflow::StringPiece; - namespace { +using absl::string_view; + constexpr int kEOF = -1; constexpr int kError = -2; @@ -66,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const { return ptr < buf_.end() && ptr >= buf_.begin(); } -tensorflow::StringPiece HloLexer::StringPieceFromPointers( - const char* begin, const char* end) const { +absl::string_view HloLexer::StringPieceFromPointers(const char* begin, + const char* end) const { CHECK(begin <= end); CHECK(begin == buf_.end() || CanDereference(begin)); CHECK(end == buf_.end() || CanDereference(end)); - return tensorflow::StringPiece(begin, end - begin); + return absl::string_view(begin, end - begin); } tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( @@ -235,7 +235,7 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kAttributeName; } - tensorflow::StringPiece identifier = + absl::string_view identifier = StringPieceFromPointers(token_start_, current_ptr_); // See if this is a keyword. 
@@ -269,7 +269,7 @@ TokKind HloLexer::LexIdentifier() { } } - str_val_ = std::string(identifier); + str_val_ = string(identifier); return TokKind::kIdent; } @@ -306,8 +306,7 @@ TokKind HloLexer::LexNumberOrPattern() { R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), - &decimal_val_); + CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_)); return TokKind::kDecimal; } @@ -339,7 +338,7 @@ TokKind HloLexer::LexNumberOrPattern() { if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); auto slice = StringPieceFromPointers(token_start_, current_ptr_); - if (tensorflow::strings::safe_strto64(slice, &int64_val_)) { + if (absl::SimpleAtoi(slice, &int64_val_)) { return TokKind::kInt; } LOG(ERROR) << "Failed to parse int literal: " << slice; @@ -365,6 +364,7 @@ std::pair HloLexer::GetLineAndColumn(LocTy location) const { line_no = line_no_cache_.line_no_of_query; } for (; ptr != location; ptr++) { + CHECK_LT(ptr, buf_.end()); if (*ptr == '\n') { line_no++; } @@ -374,24 +374,24 @@ std::pair HloLexer::GetLineAndColumn(LocTy location) const { line_no_cache_.last_query = ptr; line_no_cache_.line_no_of_query = line_no; size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); - if (line_offset == tensorflow::StringPiece::npos) { + if (line_offset == absl::string_view::npos) { line_offset = 0; } return {line_no, ptr - start - line_offset}; } -tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { +absl::string_view HloLexer::GetLine(LocTy loc) const { if (!CanDereference(loc)) { return "LINE OUT OF RANGE"; } size_t line_start = StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); - const char* start = line_start == tensorflow::StringPiece::npos + const char* start = line_start == absl::string_view::npos ? 
buf_.begin() : buf_.begin() + line_start + 1; size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); const char* end = - line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end; + line_end == absl::string_view::npos ? buf_.end() : loc + line_end; return StringPieceFromPointers(start, end); } @@ -403,10 +403,10 @@ TokKind HloLexer::LexString() { static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::StringPiece raw = + absl::string_view raw = StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); string error; - if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { + if (!absl::CUnescape(raw, &str_val_, &error)) { LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error; return TokKind::kError; } diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index 003ac34ace5713446afa74eb3af96ae33087223e..3e2f8bcd52f9043f161197756a2060b28dded1d9 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/types.h" @@ -34,7 +34,7 @@ namespace xla { // it directly. class HloLexer { public: - explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { + explicit HloLexer(absl::string_view buf) : buf_(buf) { current_ptr_ = buf_.begin(); } @@ -77,7 +77,7 @@ class HloLexer { std::pair GetLineAndColumn(LocTy location) const; // Returns the whole line given the location. 
- tensorflow::StringPiece GetLine(LocTy loc) const; + absl::string_view GetLine(LocTy loc) const; private: // Returns the current character. If it's neither the end of input buffer nor @@ -89,8 +89,8 @@ class HloLexer { // Creates StringPiece with the given begin and end. Exits if the begin > end, // or it's out of the range of the current buffer. - tensorflow::StringPiece StringPieceFromPointers(const char* begin, - const char* end) const; + absl::string_view StringPieceFromPointers(const char* begin, + const char* end) const; tensorflow::RegexpStringPiece RegexpStringPieceFromPointers( const char* begin, const char* end) const; @@ -107,11 +107,11 @@ class HloLexer { TokKind LexNumberOrPattern(); TokKind LexString(); - const tensorflow::StringPiece buf_; + const absl::string_view buf_; const char* current_ptr_; // Information about the current token. - const char* token_start_; + const char* token_start_ = nullptr; TokKind current_kind_; string str_val_; Shape shape_val_; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 43c41ece6efc4f9e8ca74f16e0f63d29abc4de4e..3a1dd471c626ae9497cfcca62c30736bcdbb2b38 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -17,8 +17,9 @@ limitations under the License. #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -29,17 +30,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { using Worklist = std::deque; using Workset = std::unordered_set; -namespace { - void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, Workset* workset) { if (workset->count(instruction) == 0) { @@ -296,7 +294,7 @@ StatusOr> HloLivenessAnalysis::Run( VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module)); + auto liveness_analysis = absl::WrapUnique(new HloLivenessAnalysis(module)); liveness_analysis->RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 7e4b8834357d39099f76450b849d6b5624e4e3b4..5269cad94d35be3dd1c009588bbe422ff1533364 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -15,15 +15,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace testing { -using ::tensorflow::str_util::Join; - bool HloMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { @@ -210,8 +208,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { *listener << instruction->ToString() << " has wrong lhs_contracting_dimensions (got {" - << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {" - << lhs_contracting_dim_ << "})"; + << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",") + << "} want {" << lhs_contracting_dim_ << "})"; return false; } @@ -219,8 +217,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { *listener << instruction->ToString() << " has wrong rhs_contracting_dimensions (got {" - << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {" - << rhs_contracting_dim_ << "})"; + << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",") + << "} want {" << rhs_contracting_dim_ << "})"; return false; } diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index c577b4359aae6c66f29860a0e56c3487b07afc02..5502e565b6dfbaca6cfa2101950fb0a68c89771f 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -16,10 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { namespace testing { @@ -120,8 +120,7 @@ class HloShapeAndLayoutMatcher class HloShardingMatcher : public ::testing::MatcherInterface { public: - explicit HloShardingMatcher( - const tensorflow::gtl::optional& sharding) + explicit HloShardingMatcher(const absl::optional& sharding) : sharding_(sharding) {} bool MatchAndExplain(const HloInstruction* instruction, @@ -129,7 +128,7 @@ class HloShardingMatcher void DescribeTo(std::ostream* os) const override; private: - tensorflow::gtl::optional sharding_; + absl::optional sharding_; }; // Matches a Dot HLO instruction with specific LHS and RHS contracting @@ -189,6 +188,7 @@ HLO_MATCHER(Fusion); HLO_MATCHER(Ge); HLO_MATCHER(AfterAll); HLO_MATCHER(Gt); +HLO_MATCHER(Iota); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); HLO_MATCHER(Le); @@ -307,7 +307,7 @@ inline ::testing::Matcher Shape( return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape)); } inline ::testing::Matcher Shape( - tensorflow::StringPiece shape) { + absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( ShapeUtil::ParseShapeString(shape).ValueOrDie())); } @@ -317,7 +317,7 @@ inline ::testing::Matcher ShapeWithLayout( new ::xla::testing::HloShapeAndLayoutMatcher(shape)); } inline ::testing::Matcher ShapeWithLayout( - tensorflow::StringPiece shape) { + absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( ShapeUtil::ParseShapeString(shape).ValueOrDie())); } @@ -330,14 +330,14 @@ inline ::testing::Matcher Sharding( } // Matcher for Sharding from sharding string inline ::testing::Matcher Sharding( - 
tensorflow::StringPiece sharding) { + absl::string_view sharding) { return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher( ParseSharding(sharding).ValueOrDie())); } // Verifies that no HloSharding is set for an HLO instruction. inline ::testing::Matcher NoSharding() { return ::testing::MakeMatcher( - new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt)); + new ::xla::testing::HloShardingMatcher(absl::nullopt)); } inline ::testing::Matcher Dot( diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 55ff073d3faf34aa0f1b8f0886946837e7a49bcc..3a1bc4e328b89d75efde7e7afeb0e52ceed4d8f9 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -22,12 +22,13 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -274,7 +275,7 @@ StatusOr> HloModule::CreateFromProto( } TF_RET_CHECK(entry != nullptr); - auto module = MakeUnique(proto.name(), module_config); + auto module = absl::make_unique(proto.name(), module_config); // Sort the computations in the proto id's order. 
std::sort(computations.begin(), computations.end(), @@ -352,7 +353,7 @@ bool IsUsedOutsideSubcomputation( } // anonymous namespace HloInstruction* HloModule::OutlineExpressionFromComputation( - tensorflow::gtl::ArraySlice instructions_to_outline, + absl::Span instructions_to_outline, const string& outlined_computation_name, HloComputation* computation) { auto builder = HloComputation::Builder(outlined_computation_name); @@ -409,7 +410,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( string error_message = "The subcomputation to outline has multiple outputs:\n"; for (HloInstruction* output : outputs) { - tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n"); + absl::StrAppend(&error_message, output->ToString(), "\n"); } LOG(FATAL) << error_message; } @@ -507,7 +508,7 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = MakeUnique(name_ + "-" + suffix, config_); + auto module = absl::make_unique(name_ + "-" + suffix, config_); HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); @@ -535,12 +536,11 @@ uint64 HloModule::RandomNew64() const { return rng_(); } -HloComputation* HloModule::GetComputationWithName( - tensorflow::StringPiece name) { +HloComputation* HloModule::GetComputationWithName(absl::string_view name) { auto computations_in_module = computations(); - auto it = c_find_if(computations_in_module, [&](HloComputation* computation) { - return computation->name() == name; - }); + auto it = absl::c_find_if( + computations_in_module, + [&](HloComputation* computation) { return computation->name() == name; }); return it == computations_in_module.end() ? 
nullptr : *it; } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index d2e726a0db63f622cd5092d56b4f746232d04aad..3c3371426b7a6a054053fe6761f87c3b5a097699 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -24,6 +24,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" @@ -32,8 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" @@ -142,7 +142,7 @@ class HloModule { // Returns the computation in this module that has the name `name`. Returns // null if there is no such computation. - HloComputation* GetComputationWithName(tensorflow::StringPiece name); + HloComputation* GetComputationWithName(absl::string_view name); // Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); } @@ -192,7 +192,7 @@ class HloModule { // order (root of outlined instructions last). TODO(jingyue): takes a set of // instructions and topologically sorts them. HloInstruction* OutlineExpressionFromComputation( - tensorflow::gtl::ArraySlice instructions_to_outline, + absl::Span instructions_to_outline, const string& outlined_computation_name, HloComputation* computation); // Returns a randomly generated uint64. 
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 07a8c798dbee072db3b75d5e99ca0dcabb5fdf6b..9bfa3a5f45c8e810f9ea7d6bdcd72b90254d15b9 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -18,15 +18,15 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::strings::StrAppend; +using absl::StrAppend; HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape, bool ignore_layouts) @@ -39,15 +39,14 @@ void HloModuleConfig::SetDefaultComputationLayout( } string HloModuleConfig::compilation_cache_key() const { - string key = - tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled()); + string key = absl::StrCat("profiling=", hlo_profiling_enabled()); StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } - StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", + StrAppend(&key, absl::StrJoin(params, ", "), ") => ", entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. 
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 074e9c90705d432b8344aebaf3c15aeb41a59fa3..3f1e1cc73eeb9debe5eb6278ab192fdf9b8cc10f 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -18,11 +18,11 @@ limitations under the License. #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -72,15 +72,6 @@ class HloModuleConfig { return debug_options_.xla_hlo_profile(); } - // Sets/returns whether this is a "host module". Host modules are used to - // record the data- and control-flow dependencies of host side computation - // that communicates with compiled code. They are used for analysis and - // scheduling purposes, but no code is generated. - bool is_host_module() const { return is_host_module_; } - void set_is_host_module(bool is_host_module) { - is_host_module_ = is_host_module; - } - // Sets/returns the module seed set during execution. void set_seed(uint64 seed) { seed_ = seed; } uint64 seed() const { return seed_; } @@ -113,7 +104,7 @@ class HloModuleConfig { private: // If you add new members, be sure to update compilation_cache_key. - tensorflow::gtl::optional entry_computation_layout_; + absl::optional entry_computation_layout_; // Whether this is a 'host module'. 
bool is_host_module_ = false; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h index 29024085c1038961ef2b3721de1ce0e8a55ccf45..12ca2340a6ccaa50780e81168c755c1fec3aa1be 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.h +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -31,7 +31,7 @@ namespace xla { class HloModuleDCE : public HloPassInterface { public: ~HloModuleDCE() override {} - tensorflow::StringPiece name() const override { return "hlo-module-dce"; } + absl::string_view name() const override { return "hlo-module-dce"; } // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 10bf9ffd6c1960df5ca2a3555d120b0874407f15..9c01862a4b7024826c3f701b795819abe945d07f 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -19,9 +19,10 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -59,7 +60,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { /* static */ StatusOr> HloModuleGroupMetadata::Build(const std::vector& modules) { - auto metadata = MakeUnique(modules); + auto metadata = absl::make_unique(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); } @@ -131,6 +132,14 @@ Status HloModuleGroupMetadata::Build() { if (VLOG_IS_ON(4)) { DumpCollectedStats(); } + + for (HloModule* module : modules_) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + points_to_analyses_[module] = std::move(points_to_analysis); + } + return Status::OK(); } @@ -163,7 +172,7 @@ Status HloModuleGroupMetadata::VerifyCompanionSets() const { ss << " " << hlo->name() << std::endl; } ss << "has multiple instructions on the same device"; - return FailedPrecondition("%s", ss.str().c_str()); + return FailedPrecondition("%s", ss.str()); } } } @@ -204,6 +213,10 @@ const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel( return channels_[channel_id_map_.at(channel_id)]; } +bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const { + return channel_id_map_.find(channel_id) != channel_id_map_.end(); +} + HloComputation* HloModuleGroupMetadata::PeerComputation( const HloInstruction* instruction) const { CHECK(IsChannelInstruction(instruction)); @@ -267,15 +280,14 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const { LOG(FATAL) << "unknown module"; } -tensorflow::gtl::optional 
HloModuleGroupMetadata::GetInstructionDevice( +absl::optional HloModuleGroupMetadata::GetInstructionDevice( const HloInstruction& instruction) const { // The module group metadata can be created in both "single module, multiple // devices" and "multiple modules, no explicit devices" fashions. // The API returns an optional even though the current implementation always // returns a device, to account for cases where we cannot guess a device. // In such cases the VerifyChannelInstructions() will return proper errors. - tensorflow::gtl::optional device = - instruction.sharding_unique_device(); + absl::optional device = instruction.sharding_unique_device(); if (!device) { device = GetModuleId(instruction.parent()->parent()); } @@ -283,10 +295,7 @@ tensorflow::gtl::optional HloModuleGroupMetadata::GetInstructionDevice( } int64 HloModuleGroupMetadata::GetDeviceModulesCount() const { - return std::count_if(modules_.begin(), modules_.end(), - [](const HloModule* module) { - return !module->config().is_host_module(); - }); + return modules_.size(); } Status HloModuleGroupMetadata::RecordInstructions() { @@ -383,7 +392,7 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, if (!ContainsKey(companion_set_index_, instruction1) && !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( - tensorflow::MakeUnique>()); + absl::make_unique>()); auto companion_set = companion_sets_.back().get(); companion_set->insert(instruction1); companion_set->insert(instruction2); @@ -411,16 +420,16 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, Status HloModuleGroupMetadata::VerifyChannelInstructions() { for (const Channel& channel : channels_) { if (channel.send == nullptr) { - return FailedPrecondition("missing send for id : %lld", channel.id); + return FailedPrecondition("missing send for id : %d", channel.id); } if (channel.recv == nullptr) { - return FailedPrecondition("missing recv for id : %lld", channel.id); + 
return FailedPrecondition("missing recv for id : %d", channel.id); } if (channel.send_done == nullptr) { - return FailedPrecondition("missing send-done for id : %lld", channel.id); + return FailedPrecondition("missing send-done for id : %d", channel.id); } if (channel.recv_done == nullptr) { - return FailedPrecondition("missing recv-done for id : %lld", channel.id); + return FailedPrecondition("missing recv-done for id : %d", channel.id); } } @@ -436,33 +445,33 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { auto send_done_device = GetInstructionDevice(*channel.send_done); if (!send_device) { return FailedPrecondition("send instruction must have a device: %s", - channel.send->ToString().c_str()); + channel.send->ToString()); } if (!send_done_device) { return FailedPrecondition("send_done instruction must have a device: %s", - channel.send_done->ToString().c_str()); + channel.send_done->ToString()); } if (*send_device != *send_done_device) { return FailedPrecondition( - "send and send-done (channel=%lld) must be on the same device: %lld " - "vs. %lld", + "send and send-done (channel=%d) must be on the same device: %d " + "vs. %d", channel.id, *send_device, *send_done_device); } auto recv_device = GetInstructionDevice(*channel.recv); auto recv_done_device = GetInstructionDevice(*channel.recv_done); if (!recv_done_device) { return FailedPrecondition("recv_done instruction must have a device: %s", - channel.recv_done->ToString().c_str()); + channel.recv_done->ToString()); } if (*recv_device != *recv_done_device) { return FailedPrecondition( - "recv and recv-done (channel=%lld) must be on the same device: %lld " - "vs. %lld", + "recv and recv-done (channel=%d) must be on the same device: %d " + "vs. 
%d", channel.id, *recv_device, *recv_done_device); } if (*send_device == *recv_device) { return FailedPrecondition( - "send and recv (channel=%lld) must be on different devices: %lld", + "send and recv (channel=%d) must be on different devices: %d", channel.id, *send_device); } } @@ -483,7 +492,7 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { !CheckCompanionPathsCompatibility( path, GetCompanionsPath(channel.recv_done))) { return FailedPrecondition( - "Nest companion paths do not match for channel %lld", channel.id); + "Nest companion paths do not match for channel %d", channel.id); } } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 84f2d3f5fbc1a6ff1df8ba3c0babd122e5701148..768b0c7eb3695715de5cef7dad1ed5a110561605 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -22,14 +22,15 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -125,6 +126,9 @@ class HloModuleGroupMetadata { // Returns the Channel instance for the given channel id. const Channel& GetChannel(int64 channel_id) const; + // Returns if the given channel id exists in metadata. + bool HasChannel(int64 channel_id) const; + // Returns the all-reduce instructions with the same all_reduce_id. 
const std::vector& GetAllReduceGroup( int64 all_reduce_id) const; @@ -156,7 +160,7 @@ class HloModuleGroupMetadata { // Retrieves the device an instruction is assigned to. Either from the // sharding information, or from the ordinal of the module the instruction // is in. - tensorflow::gtl::optional GetInstructionDevice( + absl::optional GetInstructionDevice( const HloInstruction& instruction) const; // Returns the number of modules for devices (excluding the host module). @@ -166,7 +170,7 @@ class HloModuleGroupMetadata { // // Precondition: IsCompanionWhile(instruction) is true. const std::unordered_set& Companions( - HloInstruction* instruction) const { + const HloInstruction* instruction) const { CHECK_EQ(companion_set_index_.count(instruction), 1); return companion_set(companion_set_index_.at(instruction)); } @@ -194,6 +198,10 @@ class HloModuleGroupMetadata { // Returns the maximum channel id or all_reduce_id used in the module group. int64 max_channel_id() const { return max_channel_id_; } + TuplePointsToAnalysis* points_to_analysis(HloModule* module) const { + return points_to_analyses_.at(module).get(); + } + private: Status Build(); @@ -243,7 +251,7 @@ class HloModuleGroupMetadata { companion_sets_; // Map from each companion while instruction to the index into companion_set_. - tensorflow::gtl::FlatMap companion_set_index_; + tensorflow::gtl::FlatMap companion_set_index_; // Map from computation to the instruction using it (a kWhile, kConditional). tensorflow::gtl::FlatMap @@ -268,6 +276,9 @@ class HloModuleGroupMetadata { // The modules that this metadata was built from. 
const std::vector& modules_; + + tensorflow::gtl::FlatMap> + points_to_analyses_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 9fd0ade153109c6c809c37aa08257f83a82c44d5..d83ee714905252e36f38438e81002a4d6ba7dafa 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,14 +22,17 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -37,24 +40,38 @@ namespace xla { std::vector HloModuleGroupUtil::GlobalPredecessors( HloInstruction* instruction) { - std::vector predecessors; - - // Adds to the unique predecessors list and also add companion instructions - // if the given predecessor has those. + std::vector + predecessors; // Use a vector to avoid non-determinism. + tensorflow::gtl::FlatSet unique; + + // Adds to the unique predecessors list; if the predecessors is a companion + // instruction, also add companion instructions; if the predecessors is a + // cross-module all-reduce, also add the all-reduce instructions in the same + // group. 
auto add_unique_predecessor = [&](HloInstruction* predecessor) { - if (std::find(predecessors.begin(), predecessors.end(), predecessor) != - predecessors.end()) { + if (unique.find(predecessor) != unique.end()) { return; } - if (!metadata_.IsCompanionInstruction(predecessor)) { - predecessors.push_back(predecessor); + if (metadata_.IsCompanionInstruction(predecessor)) { + for (HloInstruction* instr : metadata_.Companions(predecessor)) { + if (unique.insert(instr).second) { + predecessors.push_back(instr); + } + } return; } - for (HloInstruction* companion : metadata_.Companions(predecessor)) { - predecessors.push_back(companion); + if (predecessor->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : + metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) { + if (unique.insert(instr).second) { + predecessors.push_back(instr); + } + } + return; } + unique.insert(predecessor); + predecessors.push_back(predecessor); }; - // If the given instruction is a companion instruction, we need to find the // predecessors of all of its companion instructions. If the instruction is an // all-reduce, we need to find the predecessors of all the peer all-reduce @@ -79,12 +96,14 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( add_unique_predecessor(control_predecessor); } } - if (instruction->opcode() == HloOpcode::kRecvDone) { + if (instruction->opcode() == HloOpcode::kRecvDone && + !DynCast(instruction)->is_host_transfer()) { // Send is a remote predecessor of RecvDone. HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; add_unique_predecessor(send); } - if (instruction->opcode() == HloOpcode::kSend) { + if (instruction->opcode() == HloOpcode::kSend && + !DynCast(instruction)->is_host_transfer()) { // Recv is a remote predecessor of Send. 
HloInstruction* recv_done = metadata_.GetChannel(instruction->channel_id()).recv_done; @@ -98,22 +117,37 @@ std::vector HloModuleGroupUtil::GlobalPredecessors( std::vector HloModuleGroupUtil::GlobalSuccessors( HloInstruction* instruction) { - std::vector successors; - - // Adds to the unique successors list and also add companion instructions - // if the given successor has those. + std::vector + successors; // Use a vector to avoid non-determinism. + tensorflow::gtl::FlatSet unique; + + // Adds to the unique successors list; if the successor is a companion + // instruction, also add companion instructions; if the successor is a + // cross-module all-reduce, also add the all-reduce instructions in the same + // group. auto add_unique_successor = [&](HloInstruction* successor) { - if (std::find(successors.begin(), successors.end(), successor) != - successors.end()) { + if (unique.find(successor) != unique.end()) { return; } - if (!metadata_.IsCompanionInstruction(successor)) { - successors.push_back(successor); + if (metadata_.IsCompanionInstruction(successor)) { + for (HloInstruction* instr : metadata_.Companions(successor)) { + if (unique.insert(instr).second) { + successors.push_back(instr); + } + } return; } - for (HloInstruction* companion : metadata_.Companions(successor)) { - successors.push_back(companion); + if (successor->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : + metadata_.GetAllReduceGroup(*successor->all_reduce_id())) { + if (unique.insert(instr).second) { + successors.push_back(instr); + } + } + return; } + unique.insert(successor); + successors.push_back(successor); }; // If the given instruction is a companion instruction, we need to find the @@ -140,14 +174,16 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( add_unique_successor(control_successor); } } - if (instruction->opcode() == HloOpcode::kRecv) { + if (instruction->opcode() == HloOpcode::kRecv && + !DynCast(instruction)->is_host_transfer()) { // Send is a remote successor 
of Recv. const HloInstruction* recv_done = instruction->users().front(); CHECK(recv_done->opcode() == HloOpcode::kRecvDone); HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; add_unique_successor(send); } - if (instruction->opcode() == HloOpcode::kSend) { + if (instruction->opcode() == HloOpcode::kSend && + !DynCast(instruction)->is_host_transfer()) { // RecvDone is a remote successor of Send. HloInstruction* recv_done = metadata_.GetChannel(instruction->channel_id()).recv_done; @@ -157,7 +193,7 @@ std::vector HloModuleGroupUtil::GlobalSuccessors( } std::vector HloModuleGroupUtil::RootInstructions( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { std::vector roots; for (HloComputation* computation : computations) { for (HloInstruction* instruction : computation->instructions()) { @@ -234,8 +270,8 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( string cyclic_instructions; for (const auto& state : *visit_state) { if (state.second == VisitState::kVisiting) { - tensorflow::strings::StrAppend(&cyclic_instructions, - state.first->ToString(), "\n"); + absl::StrAppend(&cyclic_instructions, state.first->ToString(), + "\n"); } } // TODO(b/64305524): Improve the error message to print out the @@ -246,7 +282,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( "following nodes. 
Note that the order of the nodes is arbitrary " "and that the list may include nodes that are not part of the " "cycle.\n%s", - predecessor->ToString().c_str(), cyclic_instructions.c_str()); + predecessor->ToString(), cyclic_instructions); } stack.push(predecessor); } @@ -257,7 +293,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( } Status HloModuleGroupUtil::VerifyComputations( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { auto visit_function = [&](HloInstruction* instruction, const std::vector& instruction_group) { @@ -288,7 +324,7 @@ Status HloModuleGroupUtil::VerifyComputations( StatusOr> HloModuleGroupUtil::ComputeReachability( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { std::vector post_order; auto visit_function = [&](HloInstruction* instruction, @@ -302,7 +338,7 @@ HloModuleGroupUtil::ComputeReachability( TF_RETURN_IF_ERROR( VisitTopologicalOrder(&visit_states, visit_function, root)); } - auto reachability = MakeUnique(post_order); + auto reachability = absl::make_unique(post_order); for (HloInstruction* hlo : post_order) { reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h index c25ca1aff50b288f3ac3885cbed53e7ba9768430..309c23045d1e0dd91e2f245d00c51d9bf9961bf5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" @@ -27,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -56,7 +56,7 @@ class HloModuleGroupUtil { // Returns the root instructions of the computations. std::vector RootInstructions( - tensorflow::gtl::ArraySlice computations); + absl::Span computations); // Visit state of each instruction during DFS traversal. enum VisitState { @@ -93,15 +93,14 @@ class HloModuleGroupUtil { HloInstruction* root); // Verifies that the computations are well-formed (e.g., no cycles). - Status VerifyComputations( - tensorflow::gtl::ArraySlice computations); + Status VerifyComputations(absl::Span computations); // Below Reachability utils resemble those in HloComputation, except that // they can handle instructions across multiple computations. // // Creates the reachability map for the instructions in the computations. StatusOr> ComputeReachability( - tensorflow::gtl::ArraySlice computations); + absl::Span computations); // Updates the reachability of the given instruction, taking the global // predeccessorss and successors into account. diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 236f4500860a8673e61cbd2f861a8fc40c7861f7..4bc1bacd7ddd6573e75eb5e2b38b24ff5899d330 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -15,16 +15,16 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -44,7 +44,7 @@ class HloModuleTest : public HloTestBase { // Creates a computation which calls the given zero-parameter computations. std::unique_ptr CreateCallComputation( - tensorflow::gtl::ArraySlice computations) { + absl::Span computations) { auto builder = HloComputation::Builder("Call"); for (auto computation : computations) { builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index d1eaf357855205f1e9867e86f3042b96b6beff97..2d4e38589fe4693e73c46d6c82e51cb0a8388f85 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -39,7 +39,7 @@ StatusOr StringToHloOpcode(const string& opcode_name) { }); auto it = opcode_map->find(opcode_name); if (it == opcode_map->end()) { - return InvalidArgument("Unknown opcode: %s", opcode_name.c_str()); + return InvalidArgument("Unknown opcode: %s", opcode_name); } return it->second; } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index ec279867e595b66a22882703cc06046e3e916c96..e6bfb8025d4bfeba1d334d1f946e33841a2da092 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -58,6 +58,7 @@ namespace xla { V(kCall, "call", kHloOpcodeIsVariadic) \ V(kCeil, "ceil") \ V(kClamp, "clamp") \ + 
V(kCollectivePermute, "collective-permute") \ V(kClz, "count-leading-zeros") \ V(kComplex, "complex") \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ @@ -85,7 +86,6 @@ namespace xla { V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ - V(kHostCompute, "host-compute") \ V(kImag, "imag") \ V(kInfeed, "infeed") \ V(kIota, "iota") \ @@ -156,7 +156,7 @@ enum HloOpcodeProperty { // Returns a string representation of the opcode. string HloOpcodeString(HloOpcode opcode); -// Returns a string representation of the opcode. +// Retrieves the opcode enum by name if the opcode exists. StatusOr StringToHloOpcode(const string& opcode_name); inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 6c1e015f77a62c3e3ff7ffa5ce9dea735f46e10a..0581d5c40425d332d89cc92ca6c6b0b10dd8fcf1 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -25,8 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -254,6 +254,10 @@ bool HloOrdering::LiveRangeStrictlyBefore( } // All uses of 'a' must be before 'b' is defined. 
for (const HloUse& use : a.uses()) { + if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), + use.instruction)) { + continue; + } if (!UseIsBeforeValueDefinition(use, b, dataflow)) { VLOG(4) << "use of " << a << " (" << use << ") not before " << b << " is defined"; @@ -302,22 +306,20 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const { std::vector pieces; pieces.push_back(name); for (auto* computation : module_->MakeNonfusionComputations()) { - pieces.push_back(tensorflow::strings::Printf("computation %s:", - computation->name().c_str())); + pieces.push_back(absl::StrFormat("computation %s:", computation->name())); const auto all = computation->MakeInstructionPostOrder(); for (auto instruction : all) { - pieces.push_back(tensorflow::strings::Printf( - " %s predecessors:", instruction->name().c_str())); + pieces.push_back( + absl::StrFormat(" %s predecessors:", instruction->name())); for (auto predecessor : all) { if (predecessors_.at(computation) ->IsReachable(predecessor, instruction)) { - pieces.push_back( - tensorflow::strings::Printf(" %s", predecessor->name().c_str())); + pieces.push_back(absl::StrFormat(" %s", predecessor->name())); } } } } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) @@ -368,8 +370,8 @@ string SequentialHloOrdering::ToString() const { std::vector pieces; pieces.push_back("SequentialHloOrdering"); for (auto* computation : module_->computations()) { - pieces.push_back(tensorflow::strings::Printf("computation %s order:", - computation->name().c_str())); + pieces.push_back( + absl::StrFormat("computation %s order:", computation->name())); // Gather all instructions in the module sequence for this computation and // sort them by their position. 
std::vector instructions; @@ -384,11 +386,10 @@ string SequentialHloOrdering::ToString() const { return order_position_.at(a) < order_position_.at(b); }); for (auto instruction : instructions) { - pieces.push_back( - tensorflow::strings::Printf(" %s", instruction->name().c_str())); + pieces.push_back(absl::StrFormat(" %s", instruction->name())); } } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } std::ostream& operator<<( diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 4b3cd99dc06520bfeb60430d9d4316db66ea04b3..ea8e6a239a22335b644369a78791029c36315560 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -15,6 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" @@ -24,21 +30,17 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace { -using ::tensorflow::StringPiece; -using ::tensorflow::gtl::optional; -using ::tensorflow::str_util::Join; -using ::tensorflow::str_util::Split; -using ::tensorflow::str_util::SplitAndParseAsInts; -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::nullopt; +using absl::optional; +using absl::StrAppend; +using absl::StrCat; +using absl::StrFormat; +using absl::StrJoin; const double kF16max = 65504; @@ -47,7 +49,7 @@ class HloParser { public: using LocTy = HloLexer::LocTy; - explicit HloParser(StringPiece str, const HloModuleConfig& config) + explicit HloParser(absl::string_view str, const HloModuleConfig& config) : lexer_(str), config_(config) {} // Runs the parser. Returns false if an error occurred. @@ -57,14 +59,29 @@ class HloParser { std::unique_ptr ConsumeHloModule() { return std::move(module_); } // Returns the error information. - string GetError() const { return Join(error_, "\n"); } + string GetError() const { return StrJoin(error_, "\n"); } // Stand alone parsing utils for various aggregate data types. StatusOr ParseShardingOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); + StatusOr ParsePaddingConfigOnly(); + + // Stand-alone parsing utility for a single instruction worth of text. + Status ParseSingleInstruction(HloComputation::Builder* builder, + string* root_name); private: + // Locates an instruction with the given name in the instruction_pool_ or + // returns nullptr. + // + // If the missing_instruction_hook_ is registered and a "shape" is provided, + // the hook will be called and may satisfy the request for the given + // instruction. 
This is useful when we reify parameters as they're resolved; + // i.e. for ParseSingleInstruction. + std::pair* FindInstruction( + const string& name, const optional& shape = nullopt); + // ParseXXX returns false if an error occurred. bool ParseHloModule(); bool ParseComputations(); @@ -138,6 +155,7 @@ class HloParser { kFusionKind, kDistribution, kDomain, + kPrecisionList, }; struct AttrConfig { @@ -203,6 +221,7 @@ class HloParser { bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); + bool ParsePrecisionList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); @@ -221,6 +240,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); + bool ParsePrecision(PrecisionConfigProto::Precision* result); bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -233,8 +253,8 @@ class HloParser { bool CanBeParamListToShape(); // Logs the current parsing line and the given message. Always returns false. - bool TokenError(StringPiece msg); - bool Error(LocTy loc, StringPiece msg); + bool TokenError(absl::string_view msg); + bool Error(LocTy loc, absl::string_view msg); // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. @@ -265,24 +285,55 @@ class HloParser { std::vector> computations_; const HloModuleConfig config_; std::vector error_; + + // Function that gets invoked when we try to resolve an instruction + // instruction_pool_ but fail to do so. 
+ std::function*(string, + const optional&)> + missing_instruction_hook_; }; -bool HloParser::Error(LocTy loc, StringPiece msg) { +bool SplitToInt64s(absl::string_view s, char delim, std::vector* out) { + for (const auto& split : absl::StrSplit(s, delim)) { + int64 val; + if (!absl::SimpleAtoi(split, &val)) { + return false; + } + out->push_back(val); + } + return true; +} + +// Creates replica groups from the provided nested array. groups[i] represents +// the replica ids for group 'i'. +std::vector CreateReplicaGroups( + absl::Span> groups) { + std::vector replica_groups; + absl::c_transform(groups, std::back_inserter(replica_groups), + [](const std::vector& ids) { + ReplicaGroup group; + *group.mutable_replica_ids() = {ids.begin(), ids.end()}; + return group; + }); + return replica_groups; +} + +bool HloParser::Error(LocTy loc, absl::string_view msg) { auto line_col = lexer_.GetLineAndColumn(loc); const unsigned line = line_col.first; const unsigned col = line_col.second; std::vector error_lines; error_lines.push_back( StrCat("was parsing ", line, ":", col, ": error: ", msg)); - error_lines.push_back(std::string(lexer_.GetLine(loc))); + error_lines.emplace_back(lexer_.GetLine(loc)); error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(Join(error_lines, "\n")); + error_.push_back(StrJoin(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } -bool HloParser::TokenError(StringPiece msg) { +bool HloParser::TokenError(absl::string_view msg) { return Error(lexer_.GetLoc(), msg); } @@ -291,6 +342,17 @@ bool HloParser::Run() { return ParseHloModule(); } +std::pair* HloParser::FindInstruction( + const string& name, const optional& shape) { + std::pair* instr = + tensorflow::gtl::FindOrNull(instruction_pool_, name); + // Potentially call the missing instruction hook. 
+ if (instr == nullptr && missing_instruction_hook_ != nullptr) { + return missing_instruction_hook_(name, shape); + } + return instr; +} + // ::= 'HloModule' name computations bool HloParser::ParseHloModule() { if (lexer_.GetKind() != TokKind::kw_HloModule) { @@ -304,7 +366,7 @@ bool HloParser::ParseHloModule() { return false; } - module_ = MakeUnique(name, config_); + module_ = absl::make_unique(name, config_); return ParseComputations(); } @@ -357,7 +419,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { if (!ParseName(&name)) { return false; } - auto builder = MakeUnique(name); + auto builder = absl::make_unique(name); LocTy shape_loc = nullptr; Shape shape; @@ -370,8 +432,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { return false; } - std::pair* root_node = - tensorflow::gtl::FindOrNull(instruction_pool_, root_name); + std::pair* root_node = FindInstruction(root_name); // This means some instruction was marked as ROOT but we didn't find it in the // pool, which should not happen. 
if (!root_name.empty() && root_node == nullptr) { @@ -469,6 +530,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, attrs["backend_config"] = {/*required=*/false, AttrTy::kString, &backend_config}; + optional> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -498,11 +563,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kIota: { + optional iota_dimension; + attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64, + &iota_dimension}; if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateIota(shape)); + instruction = builder->AddInstruction( + HloInstruction::CreateIota(shape, *iota_dimension)); break; } // Unary ops. @@ -597,31 +666,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { + optional>> tmp_groups; optional to_apply; optional> replica_group_ids; optional barrier; optional all_reduce_id; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; - attrs["replica_group_ids"] = { - /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, &all_reduce_id}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - if (replica_group_ids) { - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, *replica_group_ids, - barrier ? 
*barrier : "", all_reduce_id)); - } else { - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, {}, barrier ? *barrier : "", - all_reduce_id)); + std::vector replica_groups; + if (tmp_groups) { + replica_groups = CreateReplicaGroups(*tmp_groups); } + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, replica_groups, + barrier ? *barrier : "", all_reduce_id)); break; } case HloOpcode::kAllToAll: { @@ -629,21 +696,36 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional barrier; attrs["replica_groups"] = {/*required=*/false, AttrTy::kBracedInt64ListList, &tmp_groups}; - attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } std::vector replica_groups; if (tmp_groups) { - c_transform(*tmp_groups, std::back_inserter(replica_groups), - [](const std::vector& ids) { - ReplicaGroup group; - *group.mutable_replica_ids() = {ids.begin(), ids.end()}; - return group; - }); + replica_groups = CreateReplicaGroups(*tmp_groups); + } + instruction = builder->AddInstruction( + HloInstruction::CreateAllToAll(shape, operands, replica_groups)); + break; + } + case HloOpcode::kCollectivePermute: { + optional>> source_targets; + attrs["source_target_pairs"] = { + /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; } - instruction = builder->AddInstruction(HloInstruction::CreateAllToAll( - shape, operands, replica_groups, barrier ? 
*barrier : "")); + std::vector> pairs(source_targets->size()); + for (int i = 0; i < pairs.size(); i++) { + if ((*source_targets)[i].size() != 2) { + return TokenError( + "expects 'source_target_pairs=' to be a list of pairs"); + } + pairs[i].first = (*source_targets)[i][0]; + pairs[i].second = (*source_targets)[i][1]; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCollectivePermute(shape, operands[0], pairs)); break; } case HloOpcode::kReshape: { @@ -825,9 +907,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kConvolution: { optional window; optional dnums; + optional feature_group_count; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/true, AttrTy::kConvolutionDimensionNumbers, &dnums}; + attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, + &feature_group_count}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; @@ -835,8 +920,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!window) { window.emplace(); } + if (!feature_group_count) { + feature_group_count = 1; + } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( - shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); + shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums, + feature_group_count.value())); break; } case HloOpcode::kFft: { @@ -909,11 +998,11 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } instruction = builder->AddInstruction(HloInstruction::CreateReduce( shape, /*operands=*/ - tensorflow::gtl::ArraySlice(operands, 0, - operands.size() / 2), + absl::Span(operands).subspan( + 0, operands.size() / 2), /*init_values=*/ - tensorflow::gtl::ArraySlice( - operands, operands.size() / 2, operands.size()), + absl::Span(operands).subspan( + operands.size() / 2, operands.size()), *dimensions_to_reduce, *reduce_computation)); break; } @@ 
-1073,7 +1162,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kInfeed: { optional config; attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } // We need to know the infeed data shape to construct the infeed @@ -1085,41 +1175,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return Error(lexer_.GetLoc(), "infeed must have a non-empty tuple shape"); } - - if (operands.empty()) { - // TODO(b/80000000): Remove this when all uses of infeed are - // converted to take tokens. - instruction = builder->AddInstruction(HloInstruction::CreateInfeed( - ShapeUtil::GetTupleElementShape(shape, 0), config ? *config : "")); - } else if (operands.size() == 1) { - instruction = builder->AddInstruction(HloInstruction::CreateInfeed( - ShapeUtil::GetTupleElementShape(shape, 0), operands[0], - config ? *config : "")); - } else { - return Error(lexer_.GetLoc(), - "infeed must have exactly zero or one operands"); - } + instruction = builder->AddInstruction(HloInstruction::CreateInfeed( + ShapeUtil::GetTupleElementShape(shape, 0), operands[0], + config ? *config : "")); break; } case HloOpcode::kOutfeed: { optional config; attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { return false; } - if (operands.size() == 1) { - // TODO(b/80000000): Remove this when all uses of outfeed are - // converted to take tokens. - instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( - operands[0]->shape(), operands[0], config ? 
*config : "")); - } else if (operands.size() == 2) { - instruction = builder->AddInstruction( - HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0], - operands[1], config ? *config : "")); - } else { - return Error(lexer_.GetLoc(), - "outfeed must have exactly one or two operands"); - } + instruction = builder->AddInstruction( + HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0], + operands[1], config ? *config : "")); break; } case HloOpcode::kRng: { @@ -1189,20 +1259,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } break; } - case HloOpcode::kHostCompute: { - optional channel_name; - optional cost_estimate_ns; - attrs["channel_name"] = {/*required=*/true, AttrTy::kString, - &channel_name}; - attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, - &cost_estimate_ns}; - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction(HloInstruction::CreateHostCompute( - shape, operands, *channel_name, *cost_estimate_ns)); - break; - } case HloOpcode::kDot: { optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { @@ -1245,22 +1301,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGather: { - optional> output_window_dims; - attrs["output_window_dims"] = { - /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; - optional> elided_window_dims; - attrs["elided_window_dims"] = { - /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; - optional> gather_dims_to_operand_dims; - attrs["gather_dims_to_operand_dims"] = {/*required=*/true, - AttrTy::kBracedInt64List, - &gather_dims_to_operand_dims}; + optional> offset_dims; + attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, + &offset_dims}; + optional> collapsed_slice_dims; + attrs["collapsed_slice_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims}; + optional> start_index_map; + 
attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List, + &start_index_map}; optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; - optional> window_bounds; - attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, - &window_bounds}; + optional> slice_sizes; + attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List, + &slice_sizes}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { @@ -1269,14 +1324,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, GatherDimensionNumbers dim_numbers = HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/*output_window_dims, - /*elided_window_dims=*/*elided_window_dims, - /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*offset_dims=*/*offset_dims, + /*collapsed_slice_dims=*/*collapsed_slice_dims, + /*start_index_map=*/*start_index_map, /*index_vector_dim=*/*index_vector_dim); instruction = builder->AddInstruction(HloInstruction::CreateGather( - shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], - dim_numbers, *window_bounds)); + shape, /*operand=*/operands[0], /*start_indices=*/operands[1], + dim_numbers, *slice_sizes)); break; } case HloOpcode::kScatter: { @@ -1359,6 +1414,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (backend_config) { instruction->set_raw_backend_config_string(std::move(*backend_config)); } + if (operand_precision) { + PrecisionConfigProto precision_config; + *precision_config.mutable_operand_precision() = {operand_precision->begin(), + operand_precision->end()}; + instruction->set_precision_config(precision_config); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1522,14 +1583,14 @@ bool HloParser::ParseDomain(DomainData* domain) { return false; } if (*kind == ShardingMetadata::KindName()) { - auto entry_sharding_ptr = MakeUnique( + auto 
entry_sharding_ptr = absl::make_unique( HloSharding::FromProto(*entry_sharding).ValueOrDie()); - auto exit_sharding_ptr = MakeUnique( + auto exit_sharding_ptr = absl::make_unique( HloSharding::FromProto(*exit_sharding).ValueOrDie()); domain->entry_metadata = - MakeUnique(std::move(entry_sharding_ptr)); + absl::make_unique(std::move(entry_sharding_ptr)); domain->exit_metadata = - MakeUnique(std::move(exit_sharding_ptr)); + absl::make_unique(std::move(exit_sharding_ptr)); } else { return TokenError(StrCat("unsupported domain kind: ", *kind)); } @@ -1549,11 +1610,9 @@ bool HloParser::ParseInstructionNames( if (!ParseName(&name)) { return Error(loc, "expects a instruction name"); } - std::pair* instr = - tensorflow::gtl::FindOrNull(instruction_pool_, name); + std::pair* instr = FindInstruction(name); if (!instr) { - return TokenError( - Printf("instruction '%s' is not defined", name.c_str())); + return TokenError(StrFormat("instruction '%s' is not defined", name)); } instructions->push_back(instr->first); } while (EatIfPresent(TokKind::kComma)); @@ -1782,10 +1841,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, std::vector elems_seen_until_dim( elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - Join(elems_seen_until_dim, ",", - [](string* out, const tensorflow::int64& num_elems) { - StrAppend(out, num_elems - 1); - }), + StrJoin(elems_seen_until_dim, ",", + [](string* out, const tensorflow::int64& num_elems) { + StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1795,17 +1854,17 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, case TokKind::kLbrace: { nest_level++; if (nest_level > rank) { - return TokenError(Printf( - "expects nested array in rank %lld, but sees larger", rank)); + return TokenError(absl::StrFormat( + "expects nested array in rank %d, but sees larger", rank)); } if (nest_level > 1) { elems_seen_per_dim[nest_level - 2]++; if (elems_seen_per_dim[nest_level - 2] > 
shape.dimensions(nest_level - 2)) { - return TokenError(Printf( - "expects %lld elements in the %sth element, but sees more", + return TokenError(absl::StrFormat( + "expects %d elements in the %sth element, but sees more", shape.dimensions(nest_level - 2), - get_index_str(nest_level - 2).c_str())); + get_index_str(nest_level - 2))); } } lexer_.Lex(); @@ -1814,9 +1873,9 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, case TokKind::kRbrace: { nest_level--; if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) { - return TokenError(Printf( - "expects %lld elements in the %sth element, but sees %lld", - shape.dimensions(nest_level), get_index_str(nest_level).c_str(), + return TokenError(absl::StrFormat( + "expects %d elements in the %sth element, but sees %d", + shape.dimensions(nest_level), get_index_str(nest_level), elems_seen_per_dim[nest_level])); } elems_seen_per_dim[nest_level] = 0; @@ -1837,15 +1896,15 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, if (rank > 0) { if (nest_level != rank) { return TokenError( - Printf("expects nested array in rank %lld, but sees %lld", rank, - nest_level)); + absl::StrFormat("expects nested array in rank %d, but sees %d", + rank, nest_level)); } elems_seen_per_dim[rank - 1]++; if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { - return TokenError( - Printf("expects %lld elements on the minor-most dimension, but " - "sees more", - shape.dimensions(rank - 1))); + return TokenError(absl::StrFormat( + "expects %d elements on the minor-most dimension, but " + "sees more", + shape.dimensions(rank - 1))); } } if (lexer_.GetKind() == TokKind::kw_true || @@ -1938,7 +1997,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = MakeUnique(shape); + *literal = absl::make_unique(shape); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { @@ -1972,7 +2031,7 @@ bool 
HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", Join(index, ", "), "]")); + ": [", StrJoin(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -2033,6 +2092,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, // ::= operand (, operand)* // operand ::= (shape)? name bool HloParser::ParseOperands(std::vector* operands) { + CHECK(operands != nullptr); if (!ParseToken(TokKind::kLparen, "expects '(' at the beginning of operands")) { return false; @@ -2043,9 +2103,10 @@ bool HloParser::ParseOperands(std::vector* operands) { do { LocTy loc = lexer_.GetLoc(); string name; + optional shape; if (CanBeShape()) { - Shape shape; - if (!ParseShape(&shape)) { + shape.emplace(); + if (!ParseShape(&shape.value())) { return false; } } @@ -2053,8 +2114,8 @@ bool HloParser::ParseOperands(std::vector* operands) { return false; } std::pair* instruction = - tensorflow::gtl::FindOrNull(instruction_pool_, name); - if (!instruction) { + FindInstruction(name, shape); + if (instruction == nullptr) { return Error(loc, StrCat("instruction does not exist: ", name)); } operands->push_back(instruction->first); @@ -2065,6 +2126,7 @@ bool HloParser::ParseOperands(std::vector* operands) { bool HloParser::ParseOperands(std::vector* operands, const int expected_size) { + CHECK(operands != nullptr); LocTy loc = lexer_.GetLoc(); if (!ParseOperands(operands)) { return false; @@ -2098,8 +2160,8 @@ bool HloParser::ParseSubAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return Error(loc, Printf("sub-attribute %s is expected but not seen", - attr_it.first.c_str())); + return Error(loc, StrFormat("sub-attribute %s is expected but not seen", + attr_it.first)); } } return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes"); @@ -2119,8 +2181,8 @@ bool 
HloParser::ParseAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return Error(loc, Printf("attribute %s is expected but not seen", - attr_it.first.c_str())); + return Error(loc, StrFormat("attribute %s is expected but not seen", + attr_it.first)); } } return true; @@ -2136,7 +2198,7 @@ bool HloParser::ParseAttributeHelper( } VLOG(1) << "Parsing attribute " << name; if (!seen_attrs->insert(name).second) { - return Error(loc, Printf("attribute %s already exists", name.c_str())); + return Error(loc, StrFormat("attribute %s already exists", name)); } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { @@ -2146,13 +2208,13 @@ bool HloParser::ParseAttributeHelper( } else { allowed_attrs = StrCat( "Allowed attributes: ", - Join(attrs, ", ", - [&](string* out, const std::pair& kv) { - StrAppend(out, kv.first); - })); + StrJoin(attrs, ", ", + [&](string* out, const std::pair& kv) { + StrAppend(out, kv.first); + })); } - return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(), - allowed_attrs.c_str())); + return Error(loc, StrFormat("unexpected attribute \"%s\". 
%s", name, + allowed_attrs)); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; @@ -2334,10 +2396,20 @@ bool HloParser::ParseAttributeHelper( case AttrTy::kDomain: { return ParseDomain(static_cast(attr_out_ptr)); } + case AttrTy::kPrecisionList: { + std::vector result; + if (!ParsePrecisionList(&result)) { + return false; + } + static_cast>*>( + attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { - return Error(loc, Printf("error parsing attribute %s", name.c_str())); + return Error(loc, StrFormat("error parsing attribute %s", name)); } return true; } @@ -2452,20 +2524,24 @@ bool HloParser::ParseConvolutionDimensionNumbers( } string str = lexer_.GetStrVal(); - // The str is expected to have 3 items, lhs, rhs, out, and it must looks like + // The str is expected to have 3 items, lhs, rhs, out, and it must look like // lhs_rhs->out, that is, the first separator is "_" and the second is "->". - // So we replace the "->" with "_" and then split on "_". 
- str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->", - /*newsub=*/"_", - /*replace_all=*/false); - std::vector lhs_rhs_out = Split(str, "_"); - if (lhs_rhs_out.size() != 3) { + std::vector split1 = absl::StrSplit(str, "_"); + if (split1.size() != 2) { LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " << str; } + std::vector split2 = absl::StrSplit(split1[1], "->"); + if (split2.size() != 2) { + LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " + << str; + } + absl::string_view lhs = split1[0]; + absl::string_view rhs = split2[0]; + absl::string_view out = split2[1]; - const tensorflow::int64 rank = lhs_rhs_out[0].length(); - if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { + const tensorflow::int64 rank = lhs.length(); + if (rank != rhs.length() || rank != out.length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); } @@ -2480,8 +2556,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( // lhs { - const string& lhs = lhs_rhs_out[0]; - if (!is_unique(lhs)) { + if (!is_unique(string(lhs))) { return TokenError( StrCat("expects unique lhs dimension numbers, but sees ", lhs)); } @@ -2498,14 +2573,13 @@ bool HloParser::ParseConvolutionDimensionNumbers( dnums->set_input_spatial_dimensions(c - '0', i); } else { return TokenError( - Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1)); + StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1)); } } } // rhs { - const string& rhs = lhs_rhs_out[1]; - if (!is_unique(rhs)) { + if (!is_unique(string(rhs))) { return TokenError( StrCat("expects unique rhs dimension numbers, but sees ", rhs)); } @@ -2522,14 +2596,13 @@ bool HloParser::ParseConvolutionDimensionNumbers( dnums->set_kernel_spatial_dimensions(c - '0', i); } else { return TokenError( - Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1)); + StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1)); } } } // output { - 
const string& out = lhs_rhs_out[2]; - if (!is_unique(out)) { + if (!is_unique(string(out))) { return TokenError( StrCat("expects unique output dimension numbers, but sees ", out)); } @@ -2545,8 +2618,8 @@ bool HloParser::ParseConvolutionDimensionNumbers( } else if (c < '0' + rank && c >= '0') { dnums->set_output_spatial_dimensions(c - '0', i); } else { - return TokenError( - Printf("expects [0-%lldbf] in output dimension numbers", rank - 1)); + return TokenError(StrFormat( + "expects [0-%dbf] in output dimension numbers", rank - 1)); } } } @@ -2592,9 +2665,10 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { } const auto& range = ranges.back(); if (range.size() != 2 && range.size() != 3) { - return Error(loc, Printf("expects [start:limit:step] or [start:limit], " - "but sees %ld elements.", - range.size())); + return Error(loc, + StrFormat("expects [start:limit:step] or [start:limit], " + "but sees %d elements.", + range.size())); } } while (EatIfPresent(TokKind::kComma)); @@ -2606,6 +2680,24 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); } +// precisionlist ::= start precision_elements end +// precision_elements +// ::= /*empty*/ +// ::= precision_val (delim precision_val)* +bool HloParser::ParsePrecisionList( + std::vector* result) { + auto parse_and_add_item = [&]() { + PrecisionConfigProto::Precision item; + if (!ParsePrecision(&item)) { + return false; + } + result->push_back(item); + return true; + }; + return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item); +} + // int64list ::= start int64_elements end // int64_elements // ::= /*empty*/ @@ -2762,14 +2854,13 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { - return Error(loc, - Printf("sub-attribute '%s=' already exists", name.c_str())); + return Error(loc, StrFormat("sub-attribute '%s=' already exists", name)); 
} // 1D if (lexer_.GetKind() == TokKind::kInt) { tensorflow::int64 number; if (!ParseInt64(&number)) { - return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); + return Error(loc, StrFormat("expects sub-attribute '%s=i'", name)); } result->push_back(number); return true; @@ -2777,9 +2868,8 @@ bool HloParser::ParseDxD(const string& name, // 2D or higher. if (lexer_.GetKind() == TokKind::kDxD) { string str = lexer_.GetStrVal(); - if (!SplitAndParseAsInts(str, 'x', result)) { - return Error(loc, - Printf("expects sub-attribute '%s=ixj...'", name.c_str())); + if (!SplitToInt64s(str, 'x', result)) { + return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name)); } lexer_.Lex(); return true; @@ -2797,10 +2887,9 @@ bool HloParser::ParseWindowPad( return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); } string str = lexer_.GetStrVal(); - std::vector padding_str = Split(str, 'x'); - for (int i = 0; i < padding_str.size(); i++) { + for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { std::vector low_high; - if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || + if (!SplitToInt64s(padding_dim_str, '_', &low_high) || low_high.size() != 2) { return Error(loc, "expects padding_low and padding_high separated by '_'"); @@ -2821,10 +2910,9 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { } LocTy loc = lexer_.GetLoc(); string str = lexer_.GetStrVal(); - std::vector padding_str = Split(str, 'x'); - for (const auto& padding_dim_str : padding_str) { + for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { std::vector padding_dim; - if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || + if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, "expects padding config pattern like 'low_high_interior' or " @@ -2876,9 +2964,8 @@ bool HloParser::ParseOpcode(HloOpcode* result) { string val = lexer_.GetStrVal(); auto 
status_or_result = StringToHloOpcode(val); if (!status_or_result.ok()) { - return TokenError( - Printf("expects opcode but sees: %s, error: %s", val.c_str(), - status_or_result.status().error_message().c_str())); + return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -2892,7 +2979,7 @@ bool HloParser::ParseFftType(FftType* result) { } string val = lexer_.GetStrVal(); if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) { - return TokenError(Printf("expects fft type but sees: %s", val.c_str())); + return TokenError(StrFormat("expects fft type but sees: %s", val)); } lexer_.Lex(); return true; @@ -2906,9 +2993,9 @@ bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { string val = lexer_.GetStrVal(); auto status_or_result = StringToFusionKind(val); if (!status_or_result.ok()) { - return TokenError( - Printf("expects fusion kind but sees: %s, error: %s", val.c_str(), - status_or_result.status().error_message().c_str())); + return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s", + val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -2924,8 +3011,25 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { auto status_or_result = StringToRandomDistribution(val); if (!status_or_result.ok()) { return TokenError( - Printf("expects random distribution but sees: %s, error: %s", - val.c_str(), status_or_result.status().error_message().c_str())); + StrFormat("expects random distribution but sees: %s, error: %s", val, + status_or_result.status().error_message())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) { + VLOG(1) << "ParsePrecision"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects
precision"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToPrecision(val); + if (!status_or_result.ok()) { + return TokenError(StrFormat("expects precision but sees: %s, error: %s", + val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -3019,7 +3123,7 @@ StatusOr HloParser::ParseShardingOnly() { lexer_.Lex(); OpSharding op_sharding; if (!ParseSharding(&op_sharding)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument("Syntax error:\nExtra content after sharding"); @@ -3031,7 +3135,7 @@ StatusOr HloParser::ParseWindowOnly() { lexer_.Lex(); Window window; if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument("Syntax error:\nExtra content after window"); @@ -3044,7 +3148,7 @@ HloParser::ParseConvolutionDimensionNumbersOnly() { lexer_.Lex(); ConvolutionDimensionNumbers dnums; if (!ParseConvolutionDimensionNumbers(&dnums)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument( @@ -3053,40 +3157,104 @@ HloParser::ParseConvolutionDimensionNumbersOnly() { return dnums; } +StatusOr HloParser::ParsePaddingConfigOnly() { + lexer_.Lex(); + PaddingConfig padding_config; + if (!ParsePaddingConfig(&padding_config)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after PaddingConfig"); + } + return padding_config; +} + +Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, + string*
root_name) { + TF_RET_CHECK(missing_instruction_hook_ == nullptr); + + // The missing instruction hook we register creates the shaped instruction on + // the fly as a parameter and returns it. + int64 parameter_count = 0; + missing_instruction_hook_ = + [this, builder, &parameter_count]( + string name, + const optional& shape) -> std::pair* { + if (!shape.has_value()) { + Error(lexer_.GetLoc(), + StrCat("Operand ", name, + " had no shape in HLO text; cannot create parameter for " + "single-instruction module.")); + return nullptr; + } + HloInstruction* parameter = builder->AddInstruction( + HloInstruction::CreateParameter(parameter_count++, *shape, name)); + instruction_pool_[name] = {parameter, lexer_.GetLoc()}; + return tensorflow::gtl::FindOrNull(instruction_pool_, name); + }; + + // Prime the lexer. + lexer_.Lex(); + + // Parse the instruction with the registered hook. + if (!ParseInstruction(builder, root_name)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + return Status::OK(); +} + } // namespace StatusOr> ParseHloString( - tensorflow::StringPiece str, const HloModuleConfig& config) { + absl::string_view str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { - return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", parser.GetError()); } return parser.ConsumeHloModule(); } -StatusOr> ParseHloString( - tensorflow::StringPiece str) { +StatusOr> ParseHloString(absl::string_view str) { HloModuleConfig config; return ParseHloString(str, config); } -StatusOr ParseSharding(tensorflow::StringPiece str) { +StatusOr> ParseHloOpToModule( + absl::string_view str, absl::string_view name) { + HloModuleConfig config; + HloParser parser(str, config); + auto builder = absl::make_unique(string(name)); + string root_name; + TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); + std::unique_ptr computation = builder->Build(); + auto module =
absl::make_unique(string(name), config); + module->AddEntryComputation(std::move(computation)); + return std::move(module); +} + +StatusOr ParseSharding(absl::string_view str) { HloModuleConfig config; HloParser parser(str, config); return parser.ParseShardingOnly(); } -StatusOr ParseWindow(tensorflow::StringPiece str) { +StatusOr ParseWindow(absl::string_view str) { HloModuleConfig config; HloParser parser(str, config); return parser.ParseWindowOnly(); } StatusOr ParseConvolutionDimensionNumbers( - tensorflow::StringPiece str) { + absl::string_view str) { HloModuleConfig config; HloParser parser(str, config); return parser.ParseConvolutionDimensionNumbersOnly(); } +StatusOr ParsePaddingConfig(absl::string_view str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParsePaddingConfigOnly(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 3f3a51215e34bbdd667f1cb20d0ae968e0ce5efd..1882a184da8f09a9626daf7a2bbc531cb6ba6138 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -16,7 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_lexer.h" @@ -32,27 +33,34 @@ namespace xla { // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with the given config. 
StatusOr> ParseHloString( - tensorflow::StringPiece str, const HloModuleConfig& config); + absl::string_view str, const HloModuleConfig& config); + +// Parses the text for a single HLO operation into an HLO module with a function +// that runs that operation (with the same parameters) as its entry computation. +StatusOr> ParseHloOpToModule( + absl::string_view str, absl::string_view name = "single_op"); // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with default config. -StatusOr> ParseHloString( - tensorflow::StringPiece str); +StatusOr> ParseHloString(absl::string_view str); // Parses the result of HloSharding::ToString(), e.g. "{replicated}". -StatusOr ParseSharding(tensorflow::StringPiece str); +StatusOr ParseSharding(absl::string_view str); // Parses the result of window_util::ToString(const Window&). -StatusOr ParseWindow(tensorflow::StringPiece str); +StatusOr ParseWindow(absl::string_view str); // Parses the result of ConvolutionDimensionNumbersToString(), e.g. // "b0f_0io->b0f". StatusOr ParseConvolutionDimensionNumbers( - tensorflow::StringPiece str); + absl::string_view str); // ParseHloString sharding from str. str is supposed to contain the body of the // sharding, i.e. just the rhs of the "sharding={...}" attribute string. -StatusOr ParseSharding(tensorflow::StringPiece str); +StatusOr ParseSharding(absl::string_view str); + +// Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". +StatusOr ParsePaddingConfig(absl::string_view str); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 5990a3d4784750feef2e375492851974214db779..759789437c12d489ee607638e736dfd6a6e1dda1 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -16,17 +16,19 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { - namespace { -using ::tensorflow::StringPiece; +namespace op = ::xla::testing::opcode_matchers; +using absl::string_view; struct TestData { string test_name; @@ -380,7 +382,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1, operand_precision={high,default} } )" @@ -393,7 +395,7 @@ R"(HloModule ConvolveR2_module ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { %input = f32[1,2]{1,0} parameter(0) %filter = f32[1,1]{1,0} parameter(1) - ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf + ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1 } )" @@ -406,7 +408,7 @@ R"(HloModule ConvolveBackward_module ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { %input = f32[128,7,7,512]{0,3,2,1} parameter(0) %filter = f32[3,3,512,512]{3,2,1,0} parameter(1) - ROOT %convolution-base-dilated = 
f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f + ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1 } )" @@ -752,10 +754,10 @@ ENTRY %sparse_f32_r1 () -> f32[9] { "gather", R"(HloModule StringifyGather -ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { +ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) - %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) - ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} + %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26} } )" @@ -1030,8 +1032,8 @@ R"(HloModule gather ENTRY Gather { input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) - gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) - ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, 
window_bounds={30,29,28,27,26} + start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26} } )" @@ -1049,7 +1051,7 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - ROOT crs = f32[8]{0} cross-replica-sum(input), replica_group_ids={}, to_apply=add + ROOT crs = f32[8]{0} cross-replica-sum(input), replica_groups={}, to_apply=add } )" @@ -1067,7 +1069,7 @@ add { ENTRY CrossReplicaSumWithSubgroups { input = f32[128,32]{0,1} parameter(0) - ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add + ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add } )" @@ -1091,7 +1093,19 @@ R"(HloModule AllToAllWithSubgroups ENTRY AllToAllWithSubgroups { input = f32[128,32]{0,1} parameter(0) - ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}, barrier="abc" + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}} +} + +)" +}, +// collective-permute +{ +"CollectivePermute", +R"(HloModule CollectivePermute + +ENTRY CollectivePermute { + input = f32[128,32]{0,1} parameter(0) + ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} } )" @@ -1102,7 +1116,7 @@ ENTRY AllToAllWithSubgroups { R"(HloModule iota ENTRY Iota { - ROOT iota = f32[100]{0} iota() + ROOT iota = f32[100]{0} iota(), iota_dimension=0 } )" @@ -1125,8 +1139,8 @@ ENTRY Computation { class HloParserTest : public ::testing::Test, public ::testing::WithParamInterface { protected: - static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected)) + static void ExpectHasSubstr(string_view s, string_view expected) { 
+ EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } @@ -1370,7 +1384,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} } )"; @@ -1390,15 +1404,14 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; - ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat( - prefix, ",dim_labels=00_01_10", suffix)) - .status() - .error_message(), - "expects dim labels pattern"); + ExpectHasSubstr( + ParseHloString(absl::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); ExpectHasSubstr( - ParseHloString(tensorflow::strings::StrCat( - prefix, ",dim_labels=010_1100->010", suffix)) + ParseHloString(absl::StrCat(prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), "must have the same rank"); @@ -1712,6 +1725,25 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); } +TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) { + const string original = "0_1x2_3"; + TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original)); + EXPECT_EQ(original, PaddingConfigToString(dnums)); +} + +TEST_F(HloParserTest, ParsePaddingConfigInteriorPadding) { + const string original = "0_1_0x2_3_4"; + TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, 
ParsePaddingConfig(original)); + EXPECT_EQ(original, PaddingConfigToString(dnums)); +} + +TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) { + TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig("0_1x2_3_4")); + // The extra "_0" gets added to the canonical string because the other dim has + // interior padding. + EXPECT_EQ("0_1_0x2_3_4", PaddingConfigToString(dnums)); +} + TEST_F(HloParserTest, NontupleInfeed) { const string original = R"(HloModule nontuple_infeed: ENTRY nontuple_infeed { @@ -1722,5 +1754,26 @@ ENTRY nontuple_infeed { "infeed must have a non-empty tuple shape"); } +TEST(HloParserSingleOpTest, SingleOp) { + const string text = + "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, " + "f32[2,4]{1,0} %x)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) { + const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)"; + StatusOr> module = ParseHloOpToModule(text); + ASSERT_TRUE(!module.status().ok()); + LOG(INFO) << "Status: " << module.status(); + EXPECT_THAT( + module.status().ToString(), + ::testing::HasSubstr("Operand broadcast had no shape in HLO text")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_fix.h b/tensorflow/compiler/xla/service/hlo_pass_fix.h index 28194deb0e32252b372a328b006dabaf250fa2c7..791b1a97b0b82edf19ff1588fd8d5d996ac0fef4 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_fix.h +++ b/tensorflow/compiler/xla/service/hlo_pass_fix.h @@ -45,7 +45,7 @@ class HloPassFix : public Pass { ++iteration_count; if (iteration_count == limit) { LOG(ERROR) - << "Unexpectedly number of iterations in HLO passes (" + << "Unexpectedly high number of iterations 
in HLO passes (" << iteration_count << ")\nIf compilation hangs here, please file a bug with XLA."; } diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h index 0cddf8fb8f7589739d1233fa4974ff703211a137..f1ad0f9b0148cb3d5f938e7f5d220d6cb82ea98d 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_interface.h +++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h @@ -29,7 +29,7 @@ namespace xla { class HloPassInterface { public: virtual ~HloPassInterface() = default; - virtual tensorflow::StringPiece name() const = 0; + virtual absl::string_view name() const = 0; // Run the pass on the given HLO module. Return whether it modified the // module. diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index d8f1ab916b5c5c500c2d8dcd8605be083f95862a..6e4ed0de626688c0d836d6bc9c619245db8d61dd 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,22 +17,23 @@ limitations under the License. 
#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace xla { - namespace { + +using absl::StrAppend; +using absl::StrCat; + void DumpModuleGraph(const HloModule& module, const string& message) { hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; @@ -48,9 +49,9 @@ void DumpModuleProto(const HloModule& module, const string& dump_to, tensorflow::mutex_lock lock(mu); const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; - const string mod_name = SanitizeFileName(tensorflow::strings::Printf( - "module_%04d.%04lld.%s.after_%s", module.unique_id(), pass_number, - pipeline_name.c_str(), pass_name.c_str())); + const string mod_name = SanitizeFileName( + absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(), + pass_number, pipeline_name, pass_name)); TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module), dump_to, mod_name)); @@ -68,7 +69,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { repeated_field.end()); if (!disabled_passes.empty()) { VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " - << tensorflow::str_util::Join(disabled_passes, ", "); + << absl::StrJoin(disabled_passes, ", "); } auto run_invariant_checkers = [this, @@ -90,7 +91,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { return Status::OK(); }; - string prefix = std::string(name()) + ": pipeline start"; + 
string prefix = StrCat(name(), ": pipeline start"); bool changed = false; string message; TF_RETURN_IF_ERROR( @@ -98,12 +99,12 @@ StatusOr HloPassPipeline::Run(HloModule* module) { const string xla_dump_per_pass_hlo_proto_to = module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - std::string(name()), "pipeline_start"); + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()), + "pipeline_start"); } for (auto& pass : passes_) { - if (disabled_passes.count(std::string(pass->name())) > 0) { + if (disabled_passes.count(string(pass->name())) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() << ", disabled by --xla_disable_hlo_passes"; continue; @@ -120,8 +121,8 @@ StatusOr HloPassPipeline::Run(HloModule* module) { TF_RETURN_IF_ERROR( run_invariant_checkers(StrCat("after running pass: ", pass->name()))); if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - std::string(name()), std::string(pass->name())); + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()), + string(pass->name())); } changed |= changed_this_pass; diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index a42d7e59fed2d838dfe3cb7f99e6b946edfdb0b4..1d41a4dac1d8e2f392be0e4e856ead36a5b71d68 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -21,7 +21,7 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,7 +34,7 @@ namespace xla { class HloPassPipeline : public HloPassInterface { public: explicit HloPassPipeline(const string& name) : name_(name) {} - tensorflow::StringPiece name() const override { return name_; } + absl::string_view name() const override { return name_; } // Add a pass to the pipeline. It should be called with the arguments for the // pass constructor: diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc index b9cca138703c8fa61aadf69dd7304a215a9f4be2..c3cacd7ce6b1ea3ad7cf84e898f274ae12622ac5 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 01b088a957554821e65db7bf9cedf334db49728f..961930f0a888e90f86e4354fa1373a303af8ec2f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -18,7 +18,7 @@ limitations under the License. 
namespace xla { HloReachabilityMap::HloReachabilityMap( - tensorflow::gtl::ArraySlice instructions) + absl::Span instructions) : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { @@ -29,7 +29,7 @@ HloReachabilityMap::HloReachabilityMap( } bool HloReachabilityMap::SetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction) { BitVector& bit_vector = GetBitVector(instruction); tmp_bit_vector_ = bit_vector; @@ -38,13 +38,13 @@ bool HloReachabilityMap::SetReachabilityToUnion( } void HloReachabilityMap::FastSetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction) { SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction)); } void HloReachabilityMap::SetReachabilityToUnionHelper( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 48215d32a8284919cce6beb1663e6a723eefc1c4..b66a2aa4bd2b00a88cdbfa6b41c9123bb370aa87 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" @@ -42,7 +42,7 @@ class HloReachabilityMap { // Sets up a graph with no edges and where the nodes correspond to the given // instructions. 
explicit HloReachabilityMap( - tensorflow::gtl::ArraySlice instructions); + absl::Span instructions); // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where @@ -54,13 +54,12 @@ class HloReachabilityMap { // vector in the internal graph of this HloReachabilityMap for the given // instruction and does not transitively update any other part of the // adjacency matrix. - bool SetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, - const HloInstruction* instruction); + bool SetReachabilityToUnion(absl::Span inputs, + const HloInstruction* instruction); // As above, but faster because it does not check if the reachability changed. void FastSetReachabilityToUnion( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction); // Sets entry so that IsReachable(a, b) will return true @@ -141,7 +140,7 @@ class HloReachabilityMap { // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. void SetReachabilityToUnionHelper( - tensorflow::gtl::ArraySlice inputs, + absl::Span inputs, const HloInstruction* instruction, BitVector* bit_vector); // Return the index of the given instruction. The value is used to index into diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index cf0be30c7ad5cbeb7fd3d71c7c649b6b448360b8..c9629926eae5132f683a353a430a724a66ef3d60 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -20,6 +20,10 @@ limitations under the License. 
#include #include +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" @@ -37,17 +41,13 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::HumanReadableNumBytes; - namespace xla { - namespace { +using ::tensorflow::strings::HumanReadableNumBytes; + // Potential optimizations: // . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue // of candidates. @@ -88,7 +88,7 @@ bool CanBeRematerialized( // Type holding a unique identifier for each Buffer object. using BufferId = int64; -using BufferIdList = tensorflow::gtl::InlinedVector; +using BufferIdList = absl::InlinedVector; // We wrap HloInstruction* with an Item that holds auxiliary // per-instruction state. @@ -123,7 +123,7 @@ struct Item { int64 position; }; -using ItemList = tensorflow::gtl::InlinedVector; +using ItemList = absl::InlinedVector; // Class which maintains an ordered list of instructions with fast insertion // before arbitrary elements. @@ -202,15 +202,14 @@ class InstructionList { // On object construction this ordinal is precisely the instruction's index // in the list. Later, instructions inserted via InsertBefore receive // duplicate values. However, monotonicity is preserved. 
- void InsertBeforeInstructions( - Item* to_insert, tensorflow::gtl::ArraySlice before_instructions) { + void InsertBeforeInstructions(Item* to_insert, + absl::Span before_instructions) { VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name() << " before {" - << tensorflow::str_util::Join(before_instructions, ", ", - [](string* out, Item* item) { - tensorflow::strings::StrAppend( - out, item->instruction->name()); - }) + << absl::StrJoin(before_instructions, ", ", + [](string* out, Item* item) { + absl::StrAppend(out, item->instruction->name()); + }) << "}"; // Find the minimal position number of any instruction in @@ -393,10 +392,9 @@ class MemoryUsageTracker { int64 unfinished_user_count; string ToString() const { - return tensorflow::strings::StrCat( - "Buffer ", id, " (defined by ", - defining_instruction->instruction->name(), ", size ", size, - " bytes)"); + return absl::StrCat("Buffer ", id, " (defined by ", + defining_instruction->instruction->name(), ", size ", + size, " bytes)"); } }; @@ -740,29 +738,27 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, } string MemoryUsageTracker::ToString() const { - string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", - computation_->name(), "\n"); - tensorflow::strings::StrAppend( - &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", - memory_usage(), " bytes)"); + string output = + absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n"); + absl::StrAppend(&output, + "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", + memory_usage(), " bytes)"); for (auto* item = instruction_list_.first(); item != nullptr; item = instruction_list_.next(item)) { const HloInstruction* instruction = item->instruction; string inprogress = item == in_progress_item_ ? " in-progress" : ""; string placed = item->placed ? 
" placed" : ""; - tensorflow::strings::StrAppend(&output, " ", instruction->name(), - inprogress, placed, "\n Defines:\n"); + absl::StrAppend(&output, " ", instruction->name(), inprogress, placed, + "\n Defines:\n"); for (BufferId buffer_id : item->buffers_defined) { const Buffer& buffer = buffers_[buffer_id]; string live = IsCurrentlyLive(buffer_id) ? " live" : ""; - tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live, - ", ", buffer.unfinished_user_count, - " unfinished uses\n"); + absl::StrAppend(&output, " ", buffer.ToString(), live, ", ", + buffer.unfinished_user_count, " unfinished uses\n"); } - tensorflow::strings::StrAppend(&output, " Uses:\n"); + absl::StrAppend(&output, " Uses:\n"); for (BufferId buffer_id : item->buffers_used) { - tensorflow::strings::StrAppend(&output, " ", - buffers_[buffer_id].ToString(), "\n"); + absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n"); } } return output; @@ -780,10 +776,9 @@ bool MemoryUsageTracker::Check() const { CHECK(elements_are_unique(defined_buffers)) << "Instruction " << instruction->name() << " does not have unique defined buffers: " - << tensorflow::str_util::Join( + << absl::StrJoin( defined_buffers, ", ", [this](string* out, BufferId buffer_id) { - tensorflow::strings::StrAppend( - out, buffers_.at(buffer_id).ToString()); + absl::StrAppend(out, buffers_.at(buffer_id).ToString()); }); for (const Buffer& buffer : buffers_) { @@ -803,10 +798,9 @@ bool MemoryUsageTracker::Check() const { CHECK(elements_are_unique(used_buffers)) << "Instruction " << instruction->name() << " does not have unique used buffers: " - << tensorflow::str_util::Join( + << absl::StrJoin( used_buffers, ", ", [this](string* out, BufferId buffer_id) { - tensorflow::strings::StrAppend( - out, buffers_.at(buffer_id).ToString()); + absl::StrAppend(out, buffers_.at(buffer_id).ToString()); }); } for (const Buffer& buffer : buffers_) { @@ -1209,6 +1203,49 @@ StatusOr HloRematerialization::Run( VLOG(1) << 
"HloRematerialization() with memory limit of " << HumanReadableNumBytes(memory_limit_bytes); + XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); + + // Create initial sequence of HLO instructions. + TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( + *module, + [this](const BufferValue& buffer) { + return size_function_(buffer.shape()); + }, + scheduler_algorithm_)); + if (copy_insertion) { + // We run a separate pass of copy elision here because the sequential + // ordering from the HLO schedule allows for more copies to be eliminated. + // TODO(b/80249101): Instead of a separate copy elision pass, use the + // ordering from the HLO schedule directly for copy insertion. + + // First create a copy of the schedule which contains HloInstruction unique + // ids instead of HloInstruction*. This is necessary for updating the + // schedule below. + // TODO(b/113175018): Remove this when the HLO schedule is self-contained + // and can update itself. + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(*sequence); + + SequentialHloOrdering ordering(module, *sequence); + TF_RETURN_IF_ERROR( + copy_insertion->RemoveUnnecessaryCopies(ordering, module)); + + // RemoveUnnecessaryCopies only considers interference when determining + // whether it is legal to remove a copy. However, copies in the graph may be + // necessary for other reason such as preventing a constant from being live + // out of the graph. So run AddSpecialCaseCopies to re-insert these copies. + // TODO(b/80249101): Break copy insertion into several passes and run each + // one once in the regular HLO pipeline. + TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module)); + + // The passes above can add and remove copies, update the schedule to + // account for these transformations. Newly added instructions will be + // placed ASAP in the schedule. 
+ TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence)); + + TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference( + SequentialHloOrdering(module, *sequence), module)); + } TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); @@ -1230,24 +1267,6 @@ StatusOr HloRematerialization::Run( << HumanReadableNumBytes(module_output_size) << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); - XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( - *module, - [this](const BufferValue& buffer) { - return size_function_(buffer.shape()); - }, - scheduler_algorithm_)); - if (copy_insertion) { - // We run a separate pass of copy elision here because the sequential - // ordering from the HLO schedule allows for more copies to be eliminated. - // TODO(b/80249101): Instead of a separate copy elision pass, use the - // ordering from the HLO schedule directly for copy insertion. - SequentialHloOrdering ordering(module, *sequence); - TF_RETURN_IF_ERROR( - copy_insertion->RemoveUnnecessaryCopies(ordering, module)); - } - // Compute peak memory usage of all computations in the module called in a // sequential context. 
call_graph_ = CallGraph::Build(module); @@ -1334,12 +1353,11 @@ StatusOr HloRematerialization::Run( XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); if (current_peak_memory > memory_limit_bytes) { - LOG(WARNING) << tensorflow::strings::Printf( - "Can't reduce memory use below %s (%lld bytes) by rematerialization; " - "only reduced to %s (%lld bytes)", - HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes, - HumanReadableNumBytes(current_peak_memory).c_str(), - current_peak_memory); + LOG(WARNING) << absl::StrFormat( + "Can't reduce memory use below %s (%d bytes) by rematerialization; " + "only reduced to %s (%d bytes)", + HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes, + HumanReadableNumBytes(current_peak_memory), current_peak_memory); } return changed; diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index b2725e2918ce76248d9f2cdbb2a6e5a63226bf9a..66ac1f66fd035074c69d070821a951fd0e357289 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -32,7 +32,7 @@ limitations under the License. 
namespace xla { /*static*/ StatusOr> -HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, +HloRunner::CreateModuleFromString(const absl::string_view hlo_string, const DebugOptions& debug_options) { HloModuleConfig config; config.set_debug_options(debug_options); @@ -106,7 +106,7 @@ StatusOr HloRunner::TransferLiteralToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice literals) { + const absl::Span literals) { std::vector buffers; for (const Literal* literal : literals) { CHECK(literal != nullptr); @@ -118,7 +118,7 @@ StatusOr> HloRunner::TransferLiteralsToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice> literals) { + const absl::Span> literals) { std::vector literal_pointers; literal_pointers.reserve(literals.size()); for (const auto& literal : literals) { @@ -137,8 +137,8 @@ StatusOr> HloRunner::TransferLiteralFromDevice( StatusOr> HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes, ExecutionProfile* profile) { + const absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { TF_ASSIGN_OR_RETURN(std::vector argument_buffers, TransferLiteralsToDevice(arguments)); TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, @@ -152,7 +152,7 @@ StatusOr> HloRunner::Execute( StatusOr> HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice> arguments, + const absl::Span> arguments, bool run_hlo_passes, ExecutionProfile* profile) { // Construct a vector of plain pointers for the arguments. std::vector argument_pointers; @@ -169,8 +169,8 @@ StatusOr> HloRunner::Execute( StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes, ExecutionProfile* profile) { + const absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { // Get service run options. 
se::Stream stream(backend().default_stream_executor()); stream.Init(); @@ -190,8 +190,8 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( StatusOr HloRunner::ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes, ExecutionProfile* profile) { + const absl::Span arguments, bool run_hlo_passes, + ExecutionProfile* profile) { std::vector argument_pointers; argument_pointers.reserve(arguments.size()); for (const auto& argument : arguments) { @@ -226,14 +226,13 @@ StatusOr>> HloRunner::ExecuteReplicated( // no arguments. std::vector argument_buffer_ptrs( options.num_replicas * options.arguments.size() + 1); - std::vector> - argument_buffer_slices; + std::vector> argument_buffer_slices; int64 index = 0; for (int64 i = 0; i < options.num_replicas; ++i) { int64 device = device_assignment(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device)); - streams.push_back(MakeUnique(executor)); + streams.push_back(absl::make_unique(executor)); streams.back()->Init(); service_run_options.emplace_back(GetServiceRunOptionsForDevice( device, streams.back().get(), &device_assignment)); @@ -260,7 +259,7 @@ StatusOr>> HloRunner::ExecuteReplicated( num_threads += options.num_replicas; } if (num_threads > 0) { - pool = MakeUnique( + pool = absl::make_unique( tensorflow::Env::Default(), "infeed_outfeed", /*num_threads=*/num_threads); } @@ -291,7 +290,7 @@ StatusOr>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = MakeUnique(); + auto literal = absl::make_unique(); TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, options.outfeed_shape, literal.get())); if (options.outfeed_values != nullptr) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 
65537f07f56e74b7fe2c2f9792af21efc7229573..76d8b92bed484381a59d7f54e0a75bb7e75649ee 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -87,8 +87,7 @@ class HloRunner { // Converts an HloModule from the given hlo textual IR string (in // HloModule::ToString format). static StatusOr> CreateModuleFromString( - const tensorflow::StringPiece hlo_string, - const DebugOptions& debug_options); + const absl::string_view hlo_string, const DebugOptions& debug_options); // Reads the proto file in xla.HloProto format, creates and returns the // HloModule. @@ -105,9 +104,9 @@ class HloRunner { // Transfers data between the host and device. StatusOr TransferLiteralToDevice(const Literal& literal); StatusOr> TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice literals); + const absl::Span literals); StatusOr> TransferLiteralsToDevice( - const tensorflow::gtl::ArraySlice> literals); + const absl::Span> literals); StatusOr> TransferLiteralFromDevice( const ShapedBuffer& buffer); @@ -118,24 +117,24 @@ class HloRunner { // optimization. 
StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice> arguments, + const absl::Span> arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); // As Execute(), but accepts and returns device buffers instead of host // buffers. StatusOr ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); StatusOr ExecuteWithDeviceBuffers( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, + const absl::Span arguments, bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); // Executes a given HLO module into a set of replicas, and returns a map diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 27cc5361cde2fa021b9489f98217ae5648afc2ad..0fc3b268c059802a3882ad5032a9fe5da28cbf23 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include +#include #include #include @@ -28,16 +29,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::HumanReadableNumBytes; - namespace xla { - namespace { +using ::tensorflow::strings::HumanReadableNumBytes; + // Class implementing a list scheduler of HLO instructions which produces a // sequence which minimizes memory usage by preferring to schedule the node that // frees bigger buffer and defines smaller outputs. @@ -582,4 +581,187 @@ StatusOr> ScheduleOneComputation( size_function, nullptr, empty_map); } +tensorflow::gtl::FlatMap> +ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) { + tensorflow::gtl::FlatMap> id_sequence; + for (const auto& computation_sequence : sequence) { + for (const HloInstruction* instruction : computation_sequence.second) { + id_sequence[computation_sequence.first].push_back( + instruction->unique_id()); + } + } + return id_sequence; +} + +Status UpdateSchedule( + const HloModule& module, + const tensorflow::gtl::FlatMap>& + id_sequence, + SequentialHloOrdering::HloModuleSequence* sequence) { + // Map from unique ID to HloInstruction pointer for instructions in the + // module. + tensorflow::gtl::FlatMap id_to_instruction; + // Set of all HloInstructions in the schedule. 
+ tensorflow::gtl::FlatSet ids_in_schedule; + std::vector nonfusion_computations = + module.MakeNonfusionComputations(); + for (const HloComputation* computation : nonfusion_computations) { + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK( + id_to_instruction.insert({instruction->unique_id(), instruction}) + .second); + } + for (int id : id_sequence.at(computation)) { + ids_in_schedule.insert(id); + } + } + + // Map from HloInstruction X to newly added instructions (instruction is in + // module, but not in schedule) which use X. If an instruction is not in the + // map, then it has no users which are newly added instructions. + tensorflow::gtl::FlatMap> + new_instruction_uses; + + // For each newly added instruction, this is the count of the instruction's + // operands that have not yet been scheduled. When this value reaches zero, + // then the instruction may be placed in the schedule. + tensorflow::gtl::FlatMap + unscheduled_operand_count; + // For each computation, this is the set of newly added instructions which + // have no operands. These must be handled specially and are added to the + // beginning of the schedule. + tensorflow::gtl::FlatMap> + new_zero_operand_instructions; + for (const HloComputation* computation : nonfusion_computations) { + new_zero_operand_instructions[computation] = {}; + for (const HloInstruction* instruction : computation->instructions()) { + if (ids_in_schedule.count(instruction->unique_id()) == 0) { + // This is a newly added instruction which is not in the schedule. 
+ for (const HloInstruction* operand : instruction->operands()) { + new_instruction_uses[operand].push_back(instruction); + } + if (instruction->operands().empty()) { + new_zero_operand_instructions[computation].push_back(instruction); + } + unscheduled_operand_count[instruction] = instruction->operand_count(); + } + } + } + + // Update the schedule with the newly added instructions, and remove any + // instructions no longer in the graph. + for (const HloComputation* computation : nonfusion_computations) { + std::vector old_computation_sequence = + std::move(sequence->at(computation)); + sequence->at(computation).clear(); + + // Create a worklist of newly added instructions which are ready to be added + // to the schedule. Initialize worklist with those that have zero operands. + std::queue worklist; + for (const HloInstruction* instruction : + new_zero_operand_instructions.at(computation)) { + worklist.push(instruction); + } + + // Lambda which schedules all instructions on the worklist. + auto schedule_worklist = [&]() { + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop(); + sequence->at(computation).push_back(instruction); + std::vector* new_users = + tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); + if (new_users != nullptr) { + // This just-scheduled instruction has users which are newly added to + // the module. Update the number of unscheduled operands and push the + // newly added instruction to the worklist if it is ready to + // schedule. 
+ for (const HloInstruction* new_user : *new_users) { + unscheduled_operand_count.at(new_user)--; + CHECK_GE(unscheduled_operand_count.at(new_user), 0); + if (unscheduled_operand_count.at(new_user) == 0) { + worklist.push(new_user); + } + } + } + } + }; + + schedule_worklist(); + for (int id : id_sequence.at(computation)) { + auto it = id_to_instruction.find(id); + if (it == id_to_instruction.end()) { + // This instruction in the schedule is no longer in the module. + continue; + } + const HloInstruction* instruction = it->second; + worklist.push(instruction); + schedule_worklist(); + } + } + + TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence)); + return Status::OK(); +} + +Status VerifySchedule( + const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& sequence) { + VLOG(2) << "VerifySchedule()"; + XLA_VLOG_LINES(2, module.ToString()); + VLOG(2) << sequence; + + // Verify the set of computations in the sequence is exactly the set of + // computations in the module. + std::vector nonfusion_computations = + module.MakeNonfusionComputations(); + TF_RET_CHECK(nonfusion_computations.size() == sequence.size()); + tensorflow::gtl::FlatSet computations_in_module( + module.computations().begin(), module.computations().end()); + for (const auto& computation_sequence : sequence) { + TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1); + } + + // For each computation verify the set of instructions is the same and that + // each dependency and control edge is honored. 
+ for (const HloComputation* computation : nonfusion_computations) { + tensorflow::gtl::FlatMap instruction_position; + int pos = 0; + for (const HloInstruction* instruction : sequence.at(computation)) { + TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) + << "Instruction " << instruction->name() + << " appears more than once in the schedule"; + pos++; + } + + TF_RET_CHECK(instruction_position.size() == + computation->instruction_count()); + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(instruction_position.count(instruction) == 1) + << "Instruction " << instruction->name() << " is not in schedule"; + } + + for (const HloInstruction* instruction : computation->instructions()) { + for (const HloInstruction* operand : instruction->operands()) { + TF_RET_CHECK(instruction_position.at(operand) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its operand " << operand->name(); + } + + for (const HloInstruction* pred : instruction->control_predecessors()) { + TF_RET_CHECK(instruction_position.at(pred) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its control predecessor " + << pred->name(); + } + } + } + + return Status::OK(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 2b33ccc8bfb895286bb3747aab0a16cf25e2cfae..d06b8d9a5cdef82380bd68ae0991a3957db80f48 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -85,6 +85,43 @@ StatusOr> ScheduleOneComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); +// Transforms the given schedule such that it is (again) a valid schedule for +// the module. 
This is used to update a schedule after the HLO module has been +// transformed in some way. In general, the only transformations to the module +// for which a schedule can be updated are the addition or removal of +// instructions to/from the module. Updating the schedule after new dependencies +// between existing instructions in the module is not supported and may result +// in an error status returned. +// +// Instructions in the module which also exist in the given schedule will remain +// in the same order in the updated schedule. Instructions which exist in the +// module but not in the given schedule will be placed as early as possible in +// the updated schedule. +// +// 'id_sequence' is a mirror of the given schedule 'sequence' but with +// HloInstruction ids rather than HloInstruction pointers. This should be +// constructed using ComputeIdSchedule below after the schedule is constructed +// but before the HLO module is transformed. +Status UpdateSchedule( + const HloModule& module, + const tensorflow::gtl::FlatMap>& + id_sequence, + SequentialHloOrdering::HloModuleSequence* sequence); + +// Constructs a copy of the given schedule but with HloInstruction unique ids +// rather than HloInstruction pointers. This is necessary for updating a +// schedule as HloInstruction pointers in the schedule may become invalid if +// instructions are removed from the module. Used by UpdateSchedule above. +// TODO(b/113175018): Remove this function when HLO schedule is its own class. +tensorflow::gtl::FlatMap> +ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence); + +// Verifies that the given schedule is valid for the given module. Specifically, +// the schedule contains exactly the instructions in the module and every +// dependency in the module is satisfied in the schedule.
+Status VerifySchedule(const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& sequence); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 9ec983c2bc353955cb23d441d200ac8aa36951b1..d49d09d459758840ce0f9f0b05e3c033da3337f8 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -244,9 +246,9 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); - // HeapSimulator accounts for subcomputations. The max mem doesn't change - // because the while body isn't live during the peak. - EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( + // HeapSimulator accounts for subcomputations. The output buffer is aliased, + // so we don't double count. 
+ EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); @@ -267,7 +269,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto abs_abs1 = builder.AddInstruction( HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( - tensorflow::gtl::ArraySlice({abs_abs1}))); + absl::Span({abs_abs1}))); auto tuple_elm = builder.AddInstruction( HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); @@ -350,7 +352,6 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { auto module = CreateNewModule(); const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); - const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); // param != 0 // Needs 17 bytes @@ -408,12 +409,259 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); - // HeapSimulator accounts for subcomputations - EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation( + // HeapSimulator accounts for subcomputations. Cond is the largest one. + // The output buffer of the while is aliased. + EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } +TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) { + // Updating the schedule of an unchanged HLO module should not affect the + // schedule at all. 
+ const string module_str = R"( +HloModule UpdateScheduleUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + std::vector entry_schedule = sequence.begin()->second; + + EXPECT_EQ(entry_schedule.size(), 6); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(entry_schedule, sequence.begin()->second); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) { + // Add some additional instructions to a module and verify the schedule can be + // updated. 
+ const string module_str = R"( +HloModule UpdateScheduleWithNewInstructions + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + + HloComputation* entry = module->entry_computation(); + const Shape shape = entry->root_instruction()->shape(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, constant, entry->root_instruction())); + entry->set_root_instruction(sub); + + auto in_schedule = [&](const HloInstruction* hlo) { + return std::find(sequence.at(entry).begin(), sequence.at(entry).end(), + hlo) != sequence.at(entry).end(); + }; + + EXPECT_EQ(sequence.at(entry).size(), 6); + EXPECT_FALSE(in_schedule(constant)); + EXPECT_FALSE(in_schedule(sub)); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(entry).size(), 8); + EXPECT_TRUE(in_schedule(constant)); + EXPECT_TRUE(in_schedule(sub)); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) { + // Add and delete some instructions from a module and verify that the schedule + // can be updated successfully. 
+ const string module_str = R"( +HloModule UpdateScheduleWithAddedAndDeletedInstruction + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + + // Set the entry root to some expression containing just a parameter and a + // constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* new_root = entry->AddInstruction( + HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, + constant, entry->parameter_instruction(0))); + entry->set_root_instruction(new_root); + + // DCE should remove everything but the parameters and the newly added code. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(sequence.at(entry).size(), 6); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(entry).size(), 4); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) { + // Completely replace a module with an entirely new set of instructions and + // verify that the schedule can be updated successfully. 
+ const string module_str = R"( +HloModule UpdateScheduleWithCompletelyReplacedModule + +ENTRY main { + a = f32[] constant(42.0) + b = f32[] constant(123.0) + ROOT sum = f32[] add(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + + // Replace the entry computation with the negation of a constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNegate, constant)); + entry->set_root_instruction(new_root); + + // DCE the old instructions. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(sequence.at(entry).size(), 3); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(entry).size(), 2); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) { + // Create changes to more than one computation in an HLO module and verify + // that the schedule can be updated. 
+ const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + + const HloInstruction* xla_while = + module->entry_computation()->root_instruction()->operand(0); + HloComputation* body = xla_while->while_body(); + HloComputation* cond = xla_while->while_condition(); + + // Negate the root of the cond. 
+ cond->set_root_instruction(cond->AddInstruction( + HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kNot, cond->root_instruction()))); + + // Replace the body with a computation which just passes through its + // parameter. + body->set_root_instruction(body->parameter_instruction(0)); + + // DCE the dead code in the body. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(sequence.at(body).size(), 7); + EXPECT_EQ(sequence.at(cond).size(), 4); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(body).size(), 1); + EXPECT_EQ(sequence.at(cond).size(), 5); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 879fb3bbab2ada0f924282f16b3d9ccb4c2cb203..de7e6b53d4d2aa88e2213248370b4da82bdeadeb 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -15,13 +15,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrCat; +using absl::StrCat; +using absl::StrJoin; HloSharding HloSharding::AssignDevice(int64 device_id) { return HloSharding(device_id); @@ -53,9 +54,8 @@ HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { return HloSharding(flattened_list); } -HloSharding HloSharding::Tuple( - const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings) { +HloSharding HloSharding::Tuple(const Shape& tuple_shape, + absl::Span shardings) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); for (auto& sharding : shardings) { CHECK(!sharding.IsTuple()) << sharding.ToString(); @@ -71,12 +71,9 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, const HloSharding& sharding) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); CHECK(!sharding.IsTuple()) << sharding.ToString(); - int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape); + int64 leaf_count = RequiredLeaves(tuple_shape); std::vector flattened_list; - flattened_list.reserve(leaf_count); - for (int64 i = 0; i < leaf_count; ++i) { - flattened_list.push_back(sharding); - } + flattened_list.resize(leaf_count, sharding); return HloSharding(flattened_list); } @@ -92,7 +89,7 @@ string HloSharding::ToString() const { for (const HloSharding& element : tuple_elements_) { parts.push_back(element.ToString()); } - return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); + return StrCat("{", absl::StrJoin(parts, ", "), "}"); } if (replicated_) { @@ -101,8 +98,8 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } else { - return StrCat("{devices=[", 
Join(tile_assignment_.dimensions(), ","), "]", - Join(tile_assignment_, ","), "}"); + return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), + "]", StrJoin(tile_assignment_, ","), "}"); } } @@ -144,7 +141,7 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { CHECK(!maximal_); CHECK(!IsTuple()); std::vector ret_index; - tile_assignment_.Each([&](tensorflow::gtl::ArraySlice index, int64 d) { + tile_assignment_.Each([&](absl::Span index, int64 d) { if (d == device) { ret_index = {index.begin(), index.end()}; } @@ -153,8 +150,7 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { return ret_index; } -int64 HloSharding::DeviceForTileIndex( - tensorflow::gtl::ArraySlice index) const { +int64 HloSharding::DeviceForTileIndex(absl::Span index) const { CHECK(!replicated_); CHECK(!IsTuple()); if (maximal_) { @@ -244,16 +240,16 @@ StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { return Tuple(ShapeTree(shape, *this)); } -tensorflow::gtl::optional HloSharding::UniqueDevice() const { +absl::optional HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } - tensorflow::gtl::optional unique_device; + absl::optional unique_device; for (auto& tuple_sharding : tuple_elements_) { auto device = tuple_sharding.UniqueDevice(); if (!device || (unique_device && *device != *unique_device)) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } unique_device = device; } @@ -262,7 +258,7 @@ tensorflow::gtl::optional HloSharding::UniqueDevice() const { if (!replicated_ && maximal_) { return static_cast(*tile_assignment_.begin()); } - return tensorflow::gtl::nullopt; + return absl::nullopt; } int64 HloSharding::GetUniqueDevice() const { @@ -321,7 +317,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, Status status = Status::OK(); std::set seen_cores; tile_assignment_.Each( - [&](tensorflow::gtl::ArraySlice indices, int32 
core) { + [&](absl::Span indices, int32 core) { // Don't overwrite a bad status, so we report the first error. if (status.ok()) { if (core >= num_devices) { @@ -431,29 +427,39 @@ Shape HloSharding::TileShape(const Shape& shape) const { HloSharding HloSharding::GetSubSharding(const Shape& shape, const ShapeIndex& index) const { CHECK(IsTuple()); - - Shape sub_shape = ShapeUtil::GetSubshape(shape, index); - ShapeTree sub_shape_tree(sub_shape, Replicate()); - sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); - return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree) - : sub_shape_tree.element(ShapeIndex({})); + int64 sharding_index = 0; + const Shape* sub_shape = &shape; + for (int64 idx : index) { + for (int64 i = 0; i < idx; ++i) { + sharding_index += + ShapeUtil::GetLeafCount(ShapeUtil::GetSubshape(*sub_shape, {i})); + } + sub_shape = &ShapeUtil::GetSubshape(*sub_shape, {idx}); + } + if (ShapeUtil::IsTuple(*sub_shape)) { + auto begin_it = tuple_elements_.begin() + sharding_index; + std::vector sub_shardings( + begin_it, begin_it + ShapeUtil::GetLeafCount(*sub_shape)); + return HloSharding::Tuple(*sub_shape, sub_shardings); + } else { + return tuple_elements_[sharding_index]; + } } -tensorflow::gtl::optional HloSharding::ExtractSingleSharding() - const { +absl::optional HloSharding::ExtractSingleSharding() const { if (!IsTuple()) { return *this; } for (int64 i = 1; i < tuple_elements_.size(); ++i) { if (tuple_elements_[0] != tuple_elements_[i]) { - return tensorflow::gtl::optional(); + return absl::nullopt; } } return tuple_elements_.front(); } size_t HloSharding::Hash() const { - if (!tuple_) { + if (tuple_) { size_t h = 0; for (const auto& element : tuple_elements_) { h = tensorflow::Hash64Combine(h, element.Hash()); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 894783e5d1538fa4e8e91b65827121f32040af83..9775505f8608ced3e33abe376f4922cc6a972726 100644 --- 
a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -23,12 +23,12 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -66,7 +66,7 @@ class HloSharding { // shardings must match the number of leaf nodes in tuple_shape. For // empty tuples, the shardings array must have one element. static HloSharding Tuple(const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings); + absl::Span shardings); // Creates a new sharding for a tuple type, with a single input sharding // repeated on each leaf. @@ -132,7 +132,7 @@ class HloSharding { // Returns the device that should execute the given tile. // It is an error to call this if is_replicated() is true. // REQUIRES: !IsTuple() - int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice index) const; + int64 DeviceForTileIndex(absl::Span index) const; // Given a device ID, returns the offset within the specified shape of the // tile that should be executed on the given core. This returns the lower @@ -151,7 +151,7 @@ class HloSharding { // span a single device, the return value will be empty. // In order for a sharding to span a single device, every leaf sharding must // be maximal and not replicated, and the used device must match. - tensorflow::gtl::optional UniqueDevice() const; + absl::optional UniqueDevice() const; // Retrieves the unique device or fails with a CHECK. int64 GetUniqueDevice() const; @@ -182,7 +182,7 @@ class HloSharding { // be returned. 
If it is a tuple, and all the tuple elements are common, the // common element will be returned. Otherwise the optional will contain no // value. - tensorflow::gtl::optional ExtractSingleSharding() const; + absl::optional ExtractSingleSharding() const; bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && @@ -260,9 +260,9 @@ class HloSharding { bool maximal_; bool tuple_; Array tile_assignment_; - // Only non-empty when tuple_ is true, but because empty tuples are allowed - // may also be empty even then. This is a flattened list of all the leaf - // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order). + // Only non-empty when tuple_ is true. If a tuple is empty then one entry is + // present for the root. This is a flattened list of all the leaf shardings in + // a tuple shape, by pre-order walk (ShapeTree iterator order). std::vector tuple_elements_; }; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index a2c1d39d0d4893333b3c2ed0e3418b01dac8cefd..34cba6136ff3fe95529f3bcf594db7776c8bfd0a 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -23,6 +24,23 @@ namespace xla { namespace { +// AssignmentKind and kUnassignedDevice are used during tuple domain sharding +// propagation in order to distinguish among three cases: +// kUnassigned: no assignment has occurred +// kAssigned: at least one assignment has occurred +// kConflict: no assignment has occurred because of conflicting propagations, +// which occurs when multiple users of an instruction have different +// shardings. +enum class AssignmentKind { kUnassigned, kAssigned, kConflict }; + +// kUnassignedDevice can only be assigned to tuple leaf shardings to indicate +// absence of sharding information for that particular sub-sharding during +// sharding propagation. It is used to be able to express tuple shardings with +// partial information. At the end of the propagation the sharding of +// tuple-shaped instructions using kUnassignedDevice's is cleared. +// TODO(b/112883246): Centralized enum of reserved devices. +constexpr int64 kUnassignedDevice = -2; + struct PassThrough { PassThrough(HloInstruction* user, HloInstruction* operand) : user(user), operand(operand) {} @@ -117,13 +135,17 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, return Status::OK(); } -std::unique_ptr CloneShardingForDomain( - const HloSharding& sharding) { - auto single_sharding = sharding.ExtractSingleSharding(); +// For tuple shardings, if every element has the same sharding then we want to +// treat them as single-element shardings so that fewer domain separations are +// inserted, as a domain can prevent some optimizations and we want to minimize +// how often that happens. 
+std::shared_ptr CloneShardingForDomain( + std::shared_ptr sharding) { + auto single_sharding = sharding->ExtractSingleSharding(); if (!single_sharding) { - return MakeUnique(sharding); + return sharding; } - return MakeUnique(*single_sharding); + return std::make_shared(*single_sharding); } Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, @@ -142,108 +164,174 @@ Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, return Status::OK(); } -// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree. -// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate() -// sharding will be returned. -ShapeTree GetTupleSharding(HloInstruction* tuple) { - if (tuple->has_sharding()) { - return tuple->sharding().GetAsShapeTree(tuple->shape()); +// Return the ShapeTree of the user argument. The user argument +// is assumed to be a user of the instruction argument. +// If user is a tuple instruction, return the tuple subsharding corresponding to +// the operand matching the instruction argument, because that is the +// subsharding corresponding to instruction. +ShapeTree GetShardingTreeFromUser( + const HloInstruction& instruction, const HloInstruction& user) { + if (user.opcode() == HloOpcode::kTuple) { + return user.sharding() + .GetSubSharding(user.shape(), {user.operand_index(&instruction)}) + .GetAsShapeTree(instruction.shape()); + } + return user.sharding().GetAsShapeTree(user.shape()); +} + +// Assign rhs to lhs. If rhs is unassigned (assigned to kUnassignedDevice) +// then no assignment is made. Therefore kUnassignedDevice is never propagated. +// kConflict is returned if lhs is already assigned and rhs is assigned to a +// different device. 
+StatusOr AssignLeafSharding(HloSharding* lhs, + const HloSharding& rhs) { + TF_RET_CHECK(!lhs->IsTuple() && !rhs.IsTuple()); + if (rhs.UsesDevice(kUnassignedDevice)) { + return AssignmentKind::kUnassigned; + } + if (lhs->UsesDevice(kUnassignedDevice)) { + *lhs = rhs; + return AssignmentKind::kAssigned; + } + return lhs->UniqueDevice() != rhs.UniqueDevice() + ? AssignmentKind::kConflict + : AssignmentKind::kUnassigned; +} + +// Assigns the whole rhs tree to lhs_tree, starting at lhs_it. +// In case of conflicting assignment AssignmentKind::kConflict is returned. In +// this case lhs_tree is partially assigned, up to the conflicting leaf. It is +// up to the caller to discard the partial assignment in case of conflict. +StatusOr AssignTreeSharding( + ShapeTree* lhs_tree, ShapeTree::iterator lhs_it, + const ShapeTree& rhs_tree) { + AssignmentKind assigned = AssignmentKind::kUnassigned; + auto rhs_it = rhs_tree.begin(); + for (; lhs_it != lhs_tree->end() && rhs_it != rhs_tree.end(); + ++lhs_it, ++rhs_it) { + // TODO(b/112885211): Add ShapeTree::IsLeaf(const ShapeTreeIterator &it) + if (rhs_tree.IsLeaf(rhs_it->first)) { + TF_RET_CHECK(lhs_tree->IsLeaf(lhs_it->first)); + TF_ASSIGN_OR_RETURN(AssignmentKind sub_assigned, + AssignLeafSharding(&lhs_it->second, rhs_it->second)); + if (sub_assigned == AssignmentKind::kConflict) { + // In case of conflict we return conflict to the caller. At this point + // partial assignments to lhs_tree may have been made already. It is up + // to the caller to discard the partial assignment in case of conflict. + return AssignmentKind::kConflict; + } else if (sub_assigned == AssignmentKind::kAssigned) { + assigned = sub_assigned; + } + } } - return ShapeTree(tuple->shape(), HloSharding::Replicate()); + TF_RET_CHECK(rhs_it == rhs_tree.end()); + return assigned; } -// Retrieves the sharding of operand, asked from a user instruction which is -// within domain. 
If operand is a kDomain, it means that sharding argument is -// the operand sharding, otherwise the operand's own sharding will be returned. -const HloSharding* GetOperandSharding(const HloInstruction* operand, +StatusOr ApplyShardingFromUsers(HloInstruction* instruction, const DomainMetadata::Domain& domain, - const HloSharding& sharding) { - // Here the user of operand is within the domain instruction set, and since it - // is user of operand, we need to look into the enter_domains set. If this is - // not a kDomain within the user domains set, then return the operand - // sharding, if any. - if (operand->opcode() != HloOpcode::kDomain || - domain.enter_domains.count(const_cast(operand)) == 0) { - return operand->has_sharding() ? &operand->sharding() : nullptr; + const HloSharding& domain_sharding) { + if (instruction->users().empty()) { + // No sharding from users, use domain_sharding, after checking + // compatibility. + TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()) && + ShapeUtil::GetLeafCount(instruction->shape()) == + domain_sharding.tuple_elements().size()); + instruction->set_sharding(domain_sharding); + return true; + } + AssignmentKind assigned = AssignmentKind::kUnassigned; + // The sharding_tree leaves are initialized to kUnassignedDevice. Only Tuple + // subshardings can result in a final sharding assignment containing + // kUnassignedDevice leaves, in case some tuple indexes are not used, or are + // used by users that don't have a sharding. + // Non-tuple shardings are either assigned to a real sharding, or are not + // assigned at all. As such they will never get assigned to kUnassignedDevice. + // In any case, kUnassignedDevice is never propagated, from the implementation + // of AssignLeafSharding. 
+ ShapeTree sharding_tree( + instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice)); + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kDomain && + domain.exit_domains.count(const_cast(user)) > 0) { + // If a user is a domain and it is registered in the domain exits, then + // the instruction sharding is taken directly from the domain, and no + // further users need to be visited. + instruction->set_sharding(domain_sharding); + return true; + } + if (!user->has_sharding()) { + continue; + } + AssignmentKind sub_assigned = AssignmentKind::kUnassigned; + ShapeTree user_sharding_tree = + GetShardingTreeFromUser(*instruction, *user); + if (ShapeUtil::IsTuple(instruction->shape())) { + // For tuple-shaped instructions collect individual tuple subshardings + // from the uses, and then combine them into the tuple sharding. + // If the user is a GTE its sharding concerns only the subtree of + // sharding_tree at index user->tuple_index, otherwise the whole + // sharding_tree is affected. + ShapeTree::iterator sharding_tree_begin = + user->opcode() == HloOpcode::kGetTupleElement + ? sharding_tree.find({user->tuple_index()}) + : sharding_tree.begin(); + TF_ASSIGN_OR_RETURN( + sub_assigned, AssignTreeSharding(&sharding_tree, sharding_tree_begin, + user_sharding_tree)); + } else { + // Non-tuple shape: assign common users sharding. + TF_RET_CHECK(user_sharding_tree.leaf_count() == 1) + << "Expected non-tuple user sharding"; + TF_ASSIGN_OR_RETURN( + sub_assigned, + AssignTreeSharding(&sharding_tree, sharding_tree.begin(), + user_sharding_tree)); + } + + if (sub_assigned == AssignmentKind::kConflict) { + // In case of conflict we don't assign any sharding. 
+ return false; + } else if (sub_assigned == AssignmentKind::kAssigned) { + assigned = sub_assigned; + } + } + + if (assigned == AssignmentKind::kAssigned) { + if (ShapeUtil::IsTuple(instruction->shape())) { + instruction->set_sharding(HloSharding::Tuple(sharding_tree)); + } else { + TF_RET_CHECK(sharding_tree.leaf_count() == 1); + instruction->set_sharding(sharding_tree.leaf_begin()->second); + } + return true; } - // At this point operand is a kDomain of the currently processed domain, so we - // can refer to sharding as the domain sharding. - return &sharding; + return false; } // Tries to propagate the sharding information into the instructions that are -// part of the domain, in a post order manner (operand propagate to user). +// part of the domain, in a reverse post order manner (users propagate to +// instruction). StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, - const HloSharding& sharding) { + const HloSharding& domain_sharding) { int64 assigned = 0; - for (HloInstruction* instruction : domain.instructions) { + // domain.instructions are ordered in a post-order manner. As we do + // user->operand propagation we process instructions in reverse order. In so + // doing we are guaranteed to process all users before their operands. 
+ for (auto it = domain.instructions.rbegin(); it != domain.instructions.rend(); + ++it) { + HloInstruction* instruction = *it; if (instruction->has_sharding()) { continue; } - if (instruction->opcode() == HloOpcode::kGetTupleElement) { - HloInstruction* tuple = instruction->mutable_operand(0); - const HloSharding* tuple_sharding = - GetOperandSharding(tuple, domain, sharding); - if (tuple_sharding != nullptr) { - if (tuple_sharding->IsTuple()) { - HloSharding sub_sharding = tuple_sharding->GetSubSharding( - tuple->shape(), {instruction->tuple_index()}); - VLOG(4) << " " << instruction->name() << " to sharding " - << sub_sharding; - instruction->set_sharding(sub_sharding); - } else { - SetSingleSharding(instruction, *tuple_sharding); - } - ++assigned; - } - } else if (instruction->opcode() == HloOpcode::kTuple) { - int64 tuple_assigned = 0; - ShapeTree shape_tree = GetTupleSharding(instruction); - for (int64 i = 0; i < instruction->operand_count(); ++i) { - const HloSharding* operand_sharding = - GetOperandSharding(instruction->operand(i), domain, sharding); - if (operand_sharding != nullptr) { - HloSharding operand_subsharding = HloSharding::Replicate(); - if (operand_sharding == &sharding) { - operand_subsharding = - sharding.GetSubSharding(instruction->shape(), {i}); - operand_sharding = &operand_subsharding; - } - if (shape_tree.element({i}) != *operand_sharding) { - *shape_tree.mutable_element({i}) = *operand_sharding; - ++tuple_assigned; - } - } - } - if (tuple_assigned > 0) { - HloSharding tuple_sharding = HloSharding::Tuple(shape_tree); - VLOG(4) << " " << instruction->name() << " to sharding " - << tuple_sharding; - instruction->set_sharding(tuple_sharding); - ++assigned; - } - } else { - // If all the operand of the given instruction has the same single device - // assignment, assign that device to this instruction as well. 
- const HloSharding* common_sharding = nullptr; - for (const HloInstruction* operand : instruction->operands()) { - const HloSharding* operand_sharding = - GetOperandSharding(operand, domain, sharding); - if (operand_sharding != nullptr) { - if (common_sharding != nullptr && - *common_sharding != *operand_sharding) { - common_sharding = nullptr; - break; - } - common_sharding = operand_sharding; - } - } - if (common_sharding != nullptr) { - VLOG(4) << " " << instruction->name() << " to sharding " - << *common_sharding; - instruction->set_sharding(*common_sharding); - ++assigned; - } + // Take the sharding from the users. + TF_ASSIGN_OR_RETURN( + bool instruction_assigned, + ApplyShardingFromUsers(instruction, domain, domain_sharding)); + if (instruction_assigned) { + ++assigned; + VLOG(4) << " " << instruction->name() << " to sharding " + << instruction->sharding(); } } return assigned; @@ -261,83 +349,40 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, return ApplyDomainSingleSharding(domain, *single_sharding); } VLOG(1) << "Assigning non-trivial sharding " << sharding; - for (;;) { - TF_ASSIGN_OR_RETURN(int64 assigned, - ApplyDomainShardingPass(domain, sharding)); - if (assigned == 0) { - break; - } - } + TF_RETURN_IF_ERROR(ApplyDomainShardingPass(domain, sharding).status()); + int64 unassigned = 0; for (HloInstruction* instruction : domain.instructions) { if (!instruction->has_sharding()) { LOG(WARNING) << "Unassigned instruction: " << instruction->ToString(); ++unassigned; + } else { + // Un-set sharding of tuples whose sub-shardings are assigned to + // kUnassignedDevice. Indeed in case of doubt it is better to leave the + // entire tuple unassigned, and let the device placer decide for it. 
+ if (instruction->sharding().UsesDevice(kUnassignedDevice)) { + TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape())) + << "Only tuples can have kUnassignedDevice sub shardings"; + instruction->clear_sharding(); + } } } // Should we error out if unassigned > 0? return Status::OK(); } -// Creates a kDomain instruction to be placed between instruction and operand. -// The kDomain instruction will be created only if the sharding differ between -// the instruction and the operand. -std::unique_ptr CreateDomain(HloInstruction* instruction, - HloInstruction* operand) { - const HloSharding* instruction_sharding = - instruction->has_sharding() ? &instruction->sharding() : nullptr; - const HloSharding* operand_sharding = - operand->has_sharding() ? &operand->sharding() : nullptr; - // No need for domain if they both have no sharding. - if (instruction_sharding == nullptr && operand_sharding == nullptr) { - return nullptr; - } - // No need for domain if they match. - if (instruction_sharding != nullptr && operand_sharding != nullptr && - ShardingMatches(*instruction_sharding, *operand_sharding)) { - return nullptr; - } - std::unique_ptr real_instruction_sharding; - std::unique_ptr real_operand_sharding; - if (instruction_sharding != nullptr) { - real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); - } - if (operand_sharding != nullptr) { - real_operand_sharding = CloneShardingForDomain(*operand_sharding); - } - VLOG(3) << "Creating domain:"; - VLOG(3) << " Instruction: " << instruction->name(); - VLOG(3) << " Operand: " << operand->name(); - VLOG(3) << " User side sharding: " - << (real_instruction_sharding != nullptr - ? real_instruction_sharding->ToString() - : "None"); - VLOG(3) << " Operand side sharding: " - << (real_operand_sharding != nullptr - ? 
real_operand_sharding->ToString() - : "None"); - - std::unique_ptr operand_side_metadata = - MakeUnique(std::move(real_operand_sharding)); - std::unique_ptr user_side_metadata = - MakeUnique(std::move(real_instruction_sharding)); - return HloInstruction::CreateDomain(operand->shape(), operand, - std::move(operand_side_metadata), - std::move(user_side_metadata)); -} - -StatusOr> ExtractOriginalCommonSharding( - tensorflow::gtl::ArraySlice instructions) { +StatusOr> ExtractOriginalCommonSharding( + absl::Span instructions) { // If we are here, all the instructions being passed had the same sharding // (or no sharding), by the means of the ShardingMatches() API. // As such, no kDomain was inserted, and here we are asked to extract the // original common sharding. // All the instructions passed to this API are part of the same computation. - const HloSharding* sharding = nullptr; + std::shared_ptr sharding; for (HloInstruction* instruction : instructions) { if (instruction->has_sharding()) { if (sharding == nullptr) { - sharding = &instruction->sharding(); + sharding = instruction->sharding_ptr(); } else { TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding())) << "Sharding " << *sharding << " does not match the one in " @@ -346,10 +391,10 @@ StatusOr> ExtractOriginalCommonSharding( } } if (sharding == nullptr) { - return std::unique_ptr(); + return std::shared_ptr(); } VLOG(4) << "Extracted sharding is " << *sharding; - return CloneShardingForDomain(*sharding); + return CloneShardingForDomain(sharding); } } // namespace @@ -357,9 +402,9 @@ StatusOr> ExtractOriginalCommonSharding( std::unique_ptr ShardingMetadata::Clone() const { std::unique_ptr sharding; if (sharding_ != nullptr) { - sharding = MakeUnique(*sharding_); + sharding = absl::make_unique(*sharding_); } - return MakeUnique(std::move(sharding)); + return absl::make_unique(std::move(sharding)); } bool ShardingMetadata::Matches(const DomainMetadata& other) const { @@ -403,7 +448,7 @@ Status 
ShardingMetadata::NormalizeShardingDomain( TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding)); } } else { - TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + TF_ASSIGN_OR_RETURN(std::shared_ptr sharding, ExtractOriginalCommonSharding(domain.instructions)); if (sharding != nullptr) { VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString(); @@ -415,9 +460,75 @@ Status ShardingMetadata::NormalizeShardingDomain( return Status::OK(); } -std::unique_ptr CreateShardingDomain( - HloInstruction* instruction, HloInstruction* operand) { - return CreateDomain(instruction, operand); +// Creates a kDomain instruction to be placed between instruction and operand. +// The kDomain instruction will be created only if the sharding differ between +// the instruction and the operand. +HloInstruction* ShardingDomainCreator::operator()(HloInstruction* instruction, + HloInstruction* root, + HloInstruction* operand) { + auto instruction_sharding = instruction->sharding_ptr(); + auto root_sharding = root->sharding_ptr(); + // No need for domain if they both have no sharding. + if (instruction_sharding == nullptr && root_sharding == nullptr) { + return nullptr; + } + // No need for domain if they match. + if (instruction_sharding != nullptr && root_sharding != nullptr && + ShardingMatches(*instruction_sharding, *root_sharding)) { + return nullptr; + } + + if (instruction_sharding != nullptr) { + instruction_sharding = CloneShardingForDomain(instruction_sharding); + } + if (root_sharding != nullptr) { + root_sharding = CloneShardingForDomain(root_sharding); + } + + auto it = domain_cse_map_.find({operand, instruction_sharding}); + if (it != domain_cse_map_.end()) { + return it->second; + } + + VLOG(3) << "Creating domain:"; + VLOG(3) << " Instruction: " << instruction->name(); + VLOG(3) << " Operand: " << operand->name(); + VLOG(3) << " User side sharding: " + << (instruction_sharding != nullptr ? 
instruction_sharding->ToString() + : "None"); + VLOG(3) << " Operand side sharding: " + << (root_sharding != nullptr ? root_sharding->ToString() : "None"); + + HloInstruction* domain = + operand->parent()->AddInstruction(HloInstruction::CreateDomain( + operand->shape(), operand, + absl::make_unique(root_sharding), + absl::make_unique(instruction_sharding))); + domain_cse_map_.emplace(DomainCseMapKey{operand, instruction_sharding}, + domain); + return domain; +} + +bool ShardingDomainCreator::DomainCseMapKey::operator==( + const ShardingDomainCreator::DomainCseMapKey& other) const { + if (instruction != other.instruction) { + return false; + } + if (sharding == nullptr && other.sharding == nullptr) { + return true; + } + if (sharding == nullptr || other.sharding == nullptr) { + return false; + } + return *sharding == *other.sharding; +} + +size_t ShardingDomainCreator::DomainCseMapHasher::operator()( + const ShardingDomainCreator::DomainCseMapKey& key) const { + return tensorflow::Hash64Combine( + std::hash{}(key.instruction), + key.sharding ? key.sharding->Hash() + : static_cast(0x297814aaad196e6dULL)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index 5e01fc0e22ae8f3421c2cb5790adf44b1200a804..cba5db927a056c760e1c4a291d96cfdbca818029 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -16,23 +16,23 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { // A DomainMetadata implementation that internally wraps a sharding attribute. class ShardingMetadata : public DomainMetadata { public: - explicit ShardingMetadata(std::unique_ptr sharding) + explicit ShardingMetadata(std::shared_ptr sharding) : sharding_(std::move(sharding)) {} std::unique_ptr Clone() const override; - tensorflow::StringPiece Kind() const override { return KindName(); } + absl::string_view Kind() const override { return KindName(); } bool Matches(const DomainMetadata& other) const override; @@ -40,7 +40,7 @@ class ShardingMetadata : public DomainMetadata { const HloSharding* sharding() const { return sharding_.get(); } - static tensorflow::StringPiece KindName() { return "sharding"; } + static absl::string_view KindName() { return "sharding"; } static StatusOr ToShardingMetadata( const DomainMetadata* metadata); @@ -55,15 +55,33 @@ class ShardingMetadata : public DomainMetadata { const DomainMetadata* metadata); private: - std::unique_ptr sharding_; + std::shared_ptr sharding_; }; -// Given an HLO graph edge between instruction and one of its operands, creates -// a ShardingMetadata based kDomain instruction if the sharding between -// instruction and operand changes. Returns nullptr if there is no need for a -// domain separation. -std::unique_ptr CreateShardingDomain( - HloInstruction* instruction, HloInstruction* operand); +// If the sharding between root and instruction changes then returns a +// ShardingMetadata based kDomain instruction that can be used to separate +// operand and instruction. 
+// Returns nullptr if there is no need for a domain separation. +class ShardingDomainCreator { + public: + HloInstruction* operator()(HloInstruction* instruction, HloInstruction* root, + HloInstruction* operand); + + private: + // Map from instruction and user sharding to domain users to CSE identical + // domains. + struct DomainCseMapKey { + const HloInstruction* instruction; + std::shared_ptr sharding; + + bool operator==(const DomainCseMapKey& other) const; + }; + struct DomainCseMapHasher { + size_t operator()(const DomainCseMapKey& key) const; + }; + std::unordered_map + domain_cse_map_; +}; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 45fc300fcaf5a301fe11768da77a7c0907919c39..80634677e78e4a35dcb9bf7de018a88122c3c030 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -29,8 +29,8 @@ limitations under the License. namespace xla { namespace { -Array MakeArray(tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice contents) { +Array MakeArray(absl::Span dimensions, + absl::Span contents) { Array a(dimensions); std::copy(contents.begin(), contents.end(), a.begin()); return a; @@ -115,6 +115,13 @@ TEST_F(HloShardingTest, Tile) { } } +// Tests that empty tuple is supported. 
+TEST_F(HloShardingTest, EmptySingleTuple) { + HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}), + HloSharding::AssignDevice(0)); + EXPECT_TRUE(sharding.ExtractSingleSharding()); +} + TEST_F(HloShardingTest, NestedTuple) { // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6]) Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({ diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h index 2ef38821af632180714911c0ff22731fd559b915..d1cf644f8273e632e2952cca0da749616e9b6233 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h @@ -24,7 +24,7 @@ namespace xla { // one arbitrarily to use and delete the others. class HloSubcomputationUnification : public HloPassInterface { public: - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "subcomputation-unification"; } diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index b78bfa0cdf4db605576fa11e18ce6c654c6a0b6d..487653344976a10e18ba667085525ba1ecbb8612 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -21,28 +23,25 @@ limitations under the License. 
#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" - -using ::tensorflow::GraphDef; -using ::tensorflow::NodeDef; -using ::tensorflow::TensorShapeProto; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; -using ::tensorflow::str_util::Join; namespace xla { namespace hlo_graph_dumper { namespace { +using absl::StrAppend; +using absl::StrCat; +using tensorflow::GraphDef; +using tensorflow::NodeDef; +using tensorflow::TensorShapeProto; + string GetOpDefName(const HloInstruction* instruction) { string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); - tensorflow::str_util::TitlecaseString(&name, "-"); + tensorflow::str_util::TitlecaseString(&name, "-"); // non-absl ok name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); if (instruction->opcode() == HloOpcode::kFusion) { string fusion_name = ToString(instruction->fusion_kind()); - StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1)); + StrAppend(&name, absl::string_view(fusion_name).substr(1)); } return name; } @@ -166,7 +165,9 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); } else { layout_string = StrCat( - "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}"); + "{", + absl::StrJoin(LayoutUtil::MinorToMajor(instruction->shape()), ","), + "}"); } attrs["layout"].set_s(layout_string); } diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 7fd99fc93050b386c5ad24e6dcd2fea1bf652c3f..773fc7d22537ab81d945c197b713b00d322a7f24 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -18,8 +18,10 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -30,16 +32,13 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; const Shape& HloPosition::shape() const { return ShapeUtil::GetSubshape(instruction->shape(), index); @@ -150,7 +149,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, } // namespace void HloValue::SetPositionsAndComputeUses( - tensorflow::gtl::ArraySlice positions) { + absl::Span positions) { CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once."; // The positions must be unique and should not contain the defining position @@ -216,14 +215,14 @@ void HloValueSet::SortAndUniquifyValues() { } string HloValueSet::ToString() const { - return StrCat("HloValueSet: ", - Join(values_, ", ", [](string* result, const HloValue* value) { - result->append(value->ToShortString()); - })); + return StrCat( + "HloValueSet: ", + absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) { + result->append(value->ToShortString()); + })); } -bool HloValueSet::AssignUnionOf( - tensorflow::gtl::ArraySlice inputs) { +bool HloValueSet::AssignUnionOf(absl::Span inputs) { HloValueSet union_set; for (const 
HloValueSet* input : inputs) { for (const HloValue* value : input->values()) { @@ -254,7 +253,7 @@ std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) { } bool InstructionValueSet::AssignUnionOf( - tensorflow::gtl::ArraySlice inputs) { + absl::Span inputs) { CHECK_GT(inputs.size(), 0); for (int i = 1; i < inputs.size(); ++i) { DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape())); diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index a1151f65e07dffdcd52f645f61dcc9b4f26459c0..b6670d409b92e8be42f5cdb40fba8d662ae83958 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -20,13 +20,13 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -108,8 +108,7 @@ class HloValue : public BufferValue { // Sets the positions in the module at which the HloValue appears. Updates // uses. Should be called once and only once. The defining position should not // be included in 'positions' as this is set at construction time. - void SetPositionsAndComputeUses( - tensorflow::gtl::ArraySlice positions); + void SetPositionsAndComputeUses(absl::Span positions); // Returns whether this value is a phi value. 
bool is_phi() const { return is_phi_; } @@ -186,14 +185,14 @@ class HloValueSet { public: HloValueSet() = default; - explicit HloValueSet(tensorflow::gtl::ArraySlice values) + explicit HloValueSet(absl::Span values) : values_(values.begin(), values.end()) { SortAndUniquifyValues(); } // Sets this value set to the union of the given value sets. Returns whether // this value set changed. - bool AssignUnionOf(tensorflow::gtl::ArraySlice inputs); + bool AssignUnionOf(absl::Span inputs); // Return the vector of HloValues in the set. Values in the vector are unique // and stably sorted by value id. @@ -247,8 +246,7 @@ class InstructionValueSet : public ShapeTree { // Sets this value set to the union of the given value sets. Returns whether // this value set changed. - bool AssignUnionOf( - tensorflow::gtl::ArraySlice inputs); + bool AssignUnionOf(absl::Span inputs); string ToString() const; }; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index e7674f3ddd5baa87c872d1c0b40bff340f3cd911..95516dec74bd253212901a3d9a92285d11fe122f 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,11 +15,13 @@ limitations under the License. 
#include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -84,7 +86,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { const Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->window(), convolution->convolution_dimension_numbers())); + convolution->window(), convolution->convolution_dimension_numbers(), + convolution->feature_group_count())); return CheckShape(convolution, expected); } @@ -114,6 +117,11 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { ShapeInference::InferAllToAllTupleShape(operand_shapes)); } +Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( + hlo->operand(0)->shape())); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), @@ -121,46 +129,35 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -namespace { - -Status CheckIsTokenOperand(const HloInstruction* instruction, - int64 operand_no) { +Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no) { const HloInstruction* token = instruction->operand(operand_no); if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { return InternalError( - "Expected operand %lld to be token-shaped, actual shape is " + "Expected 
operand %d to be token-shaped, actual shape is " "%s:\n%s", - operand_no, ShapeUtil::HumanString(token->shape()).c_str(), - instruction->ToString().c_str()); + operand_no, StringifyShape(token->shape()), instruction->ToString()); } return Status::OK(); } -Status CheckOperandAndParameter(const HloInstruction* instruction, - int64 operand_number, - const HloComputation* computation, - int64 parameter_number) { +Status ShapeVerifier::CheckOperandAndParameter( + const HloInstruction* instruction, int64 operand_number, + const HloComputation* computation, int64 parameter_number) { const HloInstruction* operand = instruction->operand(operand_number); const HloInstruction* parameter = computation->parameter_instruction(parameter_number); - if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) { + if (!ShapesSame(operand->shape(), parameter->shape())) { return InternalError("Operand %s shape does not match parameter's %s in %s", - operand->ToString().c_str(), - parameter->ToString().c_str(), - instruction->ToString().c_str()); + operand->ToString(), parameter->ToString(), + instruction->ToString()); } return Status::OK(); } -} // namespace - Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast(instruction); - // Infeed has an optional single token operand. - // TODO(b/80000000): Update when token is not optional. - if (infeed->operand_count() == 1) { - TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); - } + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); // The output of infeed is a tuple containing the data value and a token. return CheckShape(infeed, @@ -170,30 +167,20 @@ Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { HloOutfeedInstruction* outfeed = Cast(instruction); - // Outfeed has an optional token operand (operand 1). - // TODO(b/80000000): Update when token is not optional. 
- if (outfeed->operand_count() == 2) { - TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); - } + TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); // Outfeed has a separate shape field for the value which is outfed to the // host. The shape of the instruction itself is always a token. - if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), - outfeed->operand(0)->shape())) { + if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { return InternalError( - "Expected outfeed shape to be compatible with operand's shape %s, " + "Expected outfeed shape to be equal to operand's shape %s, " "actual shape is %s:\n%s", - ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), - ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(), - outfeed->ToString().c_str()); + StringifyShape(outfeed->operand(0)->shape()), + StringifyShape(outfeed->outfeed_shape()), outfeed->ToString()); } return CheckShape(outfeed, ShapeUtil::MakeTokenShape()); } -Status ShapeVerifier::HandleHostCompute(HloInstruction*) { - return Status::OK(); -} - bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, const Shape& result_shape) { @@ -207,7 +194,7 @@ bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, Status ShapeVerifier::HandleRng(HloInstruction* instruction) { if (instruction->operand_count() != 2) { return InternalError("Expected two operands for Rng instruction: %s", - instruction->ToString().c_str()); + instruction->ToString()); } const Shape& shape_0 = instruction->operand(0)->shape(); @@ -215,14 +202,14 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) { return InternalError( "Expected scalar types for the two operands of Rng instruction: %s", - instruction->ToString().c_str()); + instruction->ToString()); } if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) { return InternalError( "Expected 
compatible element types for the result and the two operands" " of Rng instruction: %s", - instruction->ToString().c_str()); + instruction->ToString()); } PrimitiveType element_type = shape_0.element_type(); @@ -235,7 +222,7 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { "Element type not supported." " Expected element to be of floating point type, integral type or" " predicate type for RngUniform: %s", - instruction->ToString().c_str()); + instruction->ToString()); } break; @@ -244,13 +231,13 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { return InternalError( "Element type not supported." " Expected element to be FloatingPointType for RngNormal: %s", - instruction->ToString().c_str()); + instruction->ToString()); } break; default: return InternalError( "Invalid Rng distribution %s", - RandomDistribution_Name(instruction->random_distribution()).c_str()); + RandomDistribution_Name(instruction->random_distribution())); } return Status::OK(); @@ -269,8 +256,8 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { return InternalError( "Expected sort to have to have the same dimensions for the keys and " "the values. Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(), - ShapeUtil::HumanString(sort->operand(1)->shape()).c_str()); + StringifyShape(sort->operand(0)->shape()), + StringifyShape(sort->operand(1)->shape())); } return CheckVariadicShape(sort); } @@ -279,10 +266,18 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { return CheckShape(constant, constant->literal().shape()); } -Status ShapeVerifier::HandleIota(HloInstruction* iota) { - return ShapeUtil::Rank(iota->shape()) == 1 - ? 
Status::OK() - : InternalError("Iota only supports arrays of rank 1."); +Status ShapeVerifier::HandleIota(HloInstruction* instruction) { + auto* iota = Cast(instruction); + const int64 rank = ShapeUtil::Rank(iota->shape()); + if (rank == 0) { + return InternalError("Iota does not support scalars."); + } + int64 iota_dimension = iota->iota_dimension(); + if (iota_dimension >= rank) { + return InternalError( + "The iota dimension cannot go beyond the operation rank."); + } + return Status::OK(); } Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { @@ -293,14 +288,13 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { - if (!ShapeUtil::IsArray(reduce->shape())) { - return InvalidArgument("Variadic reduce is not supported."); + std::vector operand_shapes; + for (const HloInstruction* operand : reduce->operands()) { + operand_shapes.push_back(&operand->shape()); } - return CheckShape( - reduce, - ShapeInference::InferReduceShape( - {&reduce->operand(0)->shape(), &reduce->operand(1)->shape()}, - reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); + return CheckShape(reduce, ShapeInference::InferReduceShape( + operand_shapes, reduce->dimensions(), + reduce->to_apply()->ComputeProgramShape())); } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { @@ -344,7 +338,18 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { return Status::OK(); } -Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { + for (HloInstruction* fused_param : fusion->fused_parameters()) { + int64 param_no = fused_param->parameter_number(); + if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { + return InternalError( + "Shape mismatch between parameter number %d and its operand in " + "%s.", + param_no, fusion->ToString().c_str()); + } 
+ } + return Status::OK(); +} Status ShapeVerifier::HandleCall(HloInstruction* call) { for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) { @@ -426,12 +431,11 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0)); const Shape& conditional_shape = xla_while->while_condition()->root_instruction()->shape(); - if (!ShapeUtil::Compatible(conditional_shape, - ShapeUtil::MakeShape(PRED, {}))) { + if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) { return InternalError( "Conditional computation shape does not lead to a scalar predicate " "shape: %s", - ShapeUtil::HumanString(conditional_shape).c_str()); + StringifyShape(conditional_shape)); } // The shape of kWhile should match the shape of the body computation it // calls. @@ -562,7 +566,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", - instruction->ToString().c_str()); + instruction->ToString()); } return Status::OK(); })); @@ -579,7 +583,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather, ShapeInference::InferGatherShape( gather->operand(0)->shape(), gather->operand(1)->shape(), - gather->gather_dimension_numbers(), gather->gather_window_bounds())); + gather->gather_dimension_numbers(), gather->gather_slice_sizes())); } Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { @@ -609,53 +613,51 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } // Check if the output shape matches the expected shape. - bool compatible; + // // We treat BF16 and F32 as compatible types if mixed precision is allowed, // but only when the instruction defines the BF16/F32 buffer. 
- switch (instruction->opcode()) { - case HloOpcode::kTupleSelect: - // TupleSelect only defines the top-level buffer, which in this case is - // the tuple, so we cannot allow mixed precision. - compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); - break; - case HloOpcode::kGetTupleElement: - case HloOpcode::kTuple: - // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed - // precision is disallowed. - case HloOpcode::kConstant: - case HloOpcode::kBitcast: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCall: - case HloOpcode::kConditional: - case HloOpcode::kConvert: - case HloOpcode::kCustomCall: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kParameter: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kWhile: - // The above opcodes should match the expected shapes exactly. - compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); - break; - default: - if (allow_mixed_precision_) { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision( - instruction->shape(), inferred_shape); - } else { - compatible = - ShapeUtil::Compatible(instruction->shape(), inferred_shape); - } - } - if (!compatible) { + bool equal = [&] { + switch (instruction->opcode()) { + // The opcodes below can't have implicit layout conversions, nor can they + // implicitly transform f32 -> bf16. Fundamentally these are either + // reinterpreting existing data (e.g. kBitcast) or shuffling data around + // without modifying it (e.g. kGetTupleElement, kTupleSelect). 
+ case HloOpcode::kBitcast: + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + return ShapesSame(instruction->shape(), inferred_shape); + + // We allow arbitrary layout and f32->bf16 transformations on all other + // instructions, although this may be made more strict pending discussion + // in b/112709536. + default: + if (allow_mixed_precision_) { + return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(), + inferred_shape); + } else { + return ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } + } + }(); + if (!equal) { return InternalError( - "Expected instruction to have shape compatible with %s, actual " + "Expected instruction to have shape equal to %s, actual " "shape is %s:\n%s", - ShapeUtil::HumanString(inferred_shape).c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - instruction->ToString().c_str()); + StringifyShape(inferred_shape), StringifyShape(instruction->shape()), + instruction->ToString()); } return Status::OK(); } @@ -697,12 +699,11 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { instruction->opcode(), instruction->operands())); } -string ComputationsToString( - tensorflow::gtl::ArraySlice computations) { - return tensorflow::str_util::Join( - computations, ",", [](string* s, const HloComputation* computation) { - s->append(computation->name()); - }); +string ComputationsToString(absl::Span computations) { + return absl::StrJoin(computations, ",", + [](string* s, const HloComputation* computation) { + s->append(computation->name()); + }); } // Verifies various invariants about the structure of the 
HLO: @@ -720,23 +721,23 @@ Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { return InternalError("Computation %s has a null parent pointer", - computation->name().c_str()); + computation->name()); } if (computation->parent() != module) { return InternalError( "Computation %s parent() does not point to parent module", - computation->name().c_str()); + computation->name()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { return InternalError("Instruction %s has a null parent pointer", - instruction->name().c_str()); + instruction->name()); } if (instruction->parent() != computation) { return InternalError( "Instruction %s parent() does not point to parent computation", - instruction->name().c_str()); + instruction->name()); } } } @@ -753,9 +754,8 @@ Status VerifyHloStructure(HloModule* module) { return InternalError( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", - i, operand->name().c_str(), instruction->name().c_str(), - operand->parent()->name().c_str(), - instruction->parent()->name().c_str()); + i, operand->name(), instruction->name(), + operand->parent()->name(), instruction->parent()->name()); } } } @@ -771,7 +771,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { "Instruction of fused computation does not match expected " "instruction " "%s.", - fusion->ToString().c_str()); + fusion->ToString()); } // Fused root instruction and fused parameters must all be owned by the @@ -785,7 +785,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (fused_root == instruction) { if (root_owned) { return InternalError("Root appears more than once in %s.", - fusion->ToString().c_str()); + fusion->ToString()); } root_owned = true; } @@ -793,7 +793,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) 
const { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { return InternalError("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + fusion->ToString()); } parameter_owned[i] = true; } @@ -801,20 +801,19 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } if (!root_owned) { return InternalError("Root not found in computation of %s.", - fusion->ToString().c_str()); + fusion->ToString()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { return InternalError("Parameter %d not found in computation of %s.", i, - fusion->ToString().c_str()); + fusion->ToString()); } } // Fused root must have no users. if (fused_root->user_count() != 0) { - return InternalError("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", fusion->ToString()); } // All uses of fused instructions must be in the fusion computation, and @@ -824,54 +823,46 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (instruction != fused_root) { if (instruction->user_count() == 0) { return InternalError("Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), - fusion->ToString().c_str()); + instruction->ToString(), fusion->ToString()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { return InternalError( "Non-root instruction %s in %s may not have external users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + instruction->ToString(), fusion->ToString()); } } } } // Fused parameter instructions must be numbered contiguously and match up - // (shapes compatible) with their respective operand. + // (shapes equal) with their respective operand. 
CHECK_EQ(fusion->operands().size(), fused_parameters.size()); std::vector parameter_numbers(fused_parameters.size(), false); for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return InternalError("Unexpected negative parameter number %lld in %s.", - param_no, fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %d in %s.", + param_no, fusion->ToString()); } if (param_no >= fused_parameters.size()) { return InternalError( - "Unexpected parameter number %lld in %s: higher then number of " + "Unexpected parameter number %d in %s: higher then number of " "parameters %lu.", - param_no, fusion->ToString().c_str(), fused_parameters.size()); + param_no, fusion->ToString(), fused_parameters.size()); } if (parameter_numbers[param_no]) { return InternalError( - "Did not expect parameter number %lld more than once in %s.", - param_no, fusion->ToString().c_str()); + "Did not expect parameter number %d more than once in %s.", param_no, + fusion->ToString()); } parameter_numbers[param_no] = true; - if (!ShapeUtil::Compatible(fused_param->shape(), - fusion->operand(param_no)->shape())) { - return InternalError( - "Shape mismatch between parameter number %lld and its operand in " - "%s.", - param_no, fusion->ToString().c_str()); - } } // Make sure all the parameter_numbers entries were seen. 
for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { return InternalError("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + fusion->ToString()); } } @@ -886,18 +877,18 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { auto* while_body = instruction->while_body(); if (while_cond->num_parameters() != 1) { return FailedPrecondition( - "While condition must have exactly 1 parameter; had %lld : %s", - while_cond->num_parameters(), while_cond->ToString().c_str()); + "While condition must have exactly 1 parameter; had %d : %s", + while_cond->num_parameters(), while_cond->ToString()); } if (while_body->num_parameters() != 1) { return FailedPrecondition( - "While body must have exactly 1 parameter; had %lld : %s", - while_body->num_parameters(), while_body->ToString().c_str()); + "While body must have exactly 1 parameter; had %d : %s", + while_body->num_parameters(), while_body->ToString()); } if (instruction->operand_count() != 1) { return FailedPrecondition( - "While loop must have exactly one operand; had %lld : %s", - instruction->operand_count(), instruction->ToString().c_str()); + "While loop must have exactly one operand; had %d : %s", + instruction->operand_count(), instruction->ToString()); } return Status::OK(); } @@ -905,16 +896,14 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) { if (instruction->true_computation()->num_parameters() != 1) { return FailedPrecondition( - "True computation %s of %s must have 1 parameter insted of %lld", - instruction->true_computation()->name().c_str(), - instruction->ToString().c_str(), + "True computation %s of %s must have 1 parameter insted of %d", + instruction->true_computation()->name(), instruction->ToString(), instruction->true_computation()->num_parameters()); } if (instruction->false_computation()->num_parameters() != 1) { return 
FailedPrecondition( - "False computation %s of %s must have 1 parameter insted of %lld", - instruction->false_computation()->name().c_str(), - instruction->ToString().c_str(), + "False computation %s of %s must have 1 parameter insted of %d", + instruction->false_computation()->name(), instruction->ToString(), instruction->false_computation()->num_parameters()); } return Status::OK(); @@ -927,11 +916,11 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { return FailedPrecondition( "Implicit broadcast is not allowed in HLO." - "Found non-compatible shapes for instruction %s.\n" + "Found different shapes for instruction %s.\n" "output: %s\noperand: %s\n", - HloOpcodeString(instruction->opcode()).c_str(), - ShapeUtil::HumanString(out_shape).c_str(), - ShapeUtil::HumanString(operand_shape).c_str()); + HloOpcodeString(instruction->opcode()), + ShapeUtil::HumanString(out_shape), + ShapeUtil::HumanString(operand_shape)); } } return Status::OK(); @@ -962,7 +951,7 @@ Status VerifyEntryAndExitShapes(const HloModule& module) { if (ShapeContainsToken(param->shape())) { return InternalError( "Entry parameter %d is or contains a token shape: %s", i, - ShapeUtil::HumanString(param->shape()).c_str()); + ShapeUtil::HumanString(param->shape())); } } return Status::OK(); @@ -974,9 +963,9 @@ Status CheckSameChannel(const HloInstruction* instr1, if (instr1->channel_id() != instr2->channel_id()) { return InternalError( "Expected to have the same channel id, actual channel ids are: %s " - "(%lld), %s (%lld)", - instr1->ToString().c_str(), instr1->channel_id(), - instr2->ToString().c_str(), instr2->channel_id()); + "(%d), %s (%d)", + instr1->ToString(), instr1->channel_id(), instr2->ToString(), + instr2->channel_id()); } return Status::OK(); } @@ -997,7 +986,7 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1, "Expected instructions to have the same is-host-transfer 
property: " "%s, " "%s ", - instr1->ToString().c_str(), instr2->ToString().c_str()); + instr1->ToString(), instr2->ToString()); } return Status::OK(); } @@ -1014,12 +1003,12 @@ Status VerifySendsAndRecvs(const HloModule& module) { host_channels.insert({sendrecv->channel_id(), sendrecv}); if (!it_inserted.second) { return FailedPrecondition( - "Channel %lld is used for multiple host send/recv instructions: " + "Channel %d is used for multiple host send/recv instructions: " "%s " "and " "%s", - sendrecv->channel_id(), sendrecv->ToString().c_str(), - it_inserted.first->second->ToString().c_str()); + sendrecv->channel_id(), sendrecv->ToString(), + it_inserted.first->second->ToString()); } } @@ -1078,9 +1067,9 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RET_CHECK(instruction->parent() == computation); if (instruction->opcode() == HloOpcode::kFusion) { TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction)); - TF_RET_CHECK( - ContainersEqual(instruction->called_computations(), - {instruction->fused_instructions_computation()})) + TF_RET_CHECK(instruction->called_computations() == + absl::Span( + {instruction->fused_instructions_computation()})) << "Fusion HLO calls computations other than the " "fused_instructions_computation: " << instruction->ToString() diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index c942fab08e1ace75bccb8762954787a4366922a9..42e3027bf14a827bd0a791510c2d9c107d989ab9 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/shape_inference.h" namespace xla { @@ -27,9 +28,9 @@ namespace xla { // TODO(b/26024837): Check output shape for all instruction types. 
class ShapeVerifier : public DfsHloVisitor { public: - explicit ShapeVerifier() : allow_mixed_precision_(false) {} - explicit ShapeVerifier(bool allow_mixed_precision) - : allow_mixed_precision_(allow_mixed_precision) {} + explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision) + : layout_sensitive_(layout_sensitive), + allow_mixed_precision_(allow_mixed_precision) {} Status HandleElementwiseUnary(HloInstruction* hlo) override; Status HandleElementwiseBinary(HloInstruction* hlo) override; @@ -46,6 +47,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; + Status HandleCollectivePermute(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; @@ -63,7 +65,6 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFusion(HloInstruction*) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction*) override; - Status HandleHostCompute(HloInstruction*) override; Status HandleSlice(HloInstruction* slice) override; Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( @@ -106,13 +107,42 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckVariadicShape(const HloInstruction* instruction); private: - // Return true if the shapes of the two operands have the same element type, - // and the result shape either has the same element type as the operand - // shapes or mixed precision is allowed and the result shape and the operand - // shapes have floating point element types. + // Helpers that switch on layout_sensitive_. + bool ShapesSame(const Shape& a, const Shape& b) { + return layout_sensitive_ ? 
ShapeUtil::Equal(a, b) + : ShapeUtil::Compatible(a, b); + } + bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) { + return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b) + : ShapeUtil::CompatibleIgnoringFpPrecision(a, b); + } + string StringifyShape(const Shape& s) { + return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s) + : ShapeUtil::HumanString(s); + } + + // Checks that the given operand of the given instruction is of type TOKEN. + Status CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no); + + // Checks that the shape of the given operand of the given instruction matches + // the given parameter of the given computation. + Status CheckOperandAndParameter(const HloInstruction* instruction, + int64 operand_number, + const HloComputation* computation, + int64 parameter_number); + + // Returns true if the shapes of the two operands have the same element type, + // and the result shape either has the same element type as the operand shapes + // or mixed precision is allowed and the result shape and the operand shapes + // have floating point element types. bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, const Shape& result_shape); + // If the verifier is layout-sensitive, shapes must be equal to what's + // expected. Otherwise, the shapes must simply be compatible. + bool layout_sensitive_; + // Whether the inputs and output of an instruction can contain both F32s and // BF16s. Tuples that include both F32s and BF16s are allowed regardless of // this flag. @@ -125,14 +155,10 @@ class HloVerifier : public HloPassInterface { public: using ShapeVerifierFactory = std::function()>; - // Uses standard shape inference. 
- explicit HloVerifier() - : shape_verifier_factory_( - [] { return MakeUnique(false); }) {} - - explicit HloVerifier(bool allow_mixed_precision) - : shape_verifier_factory_([allow_mixed_precision] { - return MakeUnique(allow_mixed_precision); + explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision) + : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] { + return absl::make_unique(layout_sensitive, + allow_mixed_precision); }) {} // Uses custom shape verification. @@ -140,10 +166,9 @@ class HloVerifier : public HloPassInterface { : shape_verifier_factory_(std::move(shape_verifier_factory)) {} ~HloVerifier() override = default; - tensorflow::StringPiece name() const override { return "verifier"; } + absl::string_view name() const override { return "verifier"; } - // Note: always returns false (no instructions are ever modified by this - // pass). + // Never returns true; no instructions are ever modified by this pass. StatusOr Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index d764964f3c3dc58a54bd0307f8b625076c14f3e5..0cac210c2413e979300e191cb54860bcd0ab79b5 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -34,16 +34,20 @@ namespace { using ::testing::HasSubstr; +// This class cannot be converted to use HloVerifiedTestBase. It explicitly +// uses HloTestBase to create and test malformed HLOs. 
class HloVerifierTest : public HloTestBase { public: HloVerifierTest() - : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/false) {} }; class HloVerifierTestAllowMixedPrecision : public HloTestBase { public: HloVerifierTestAllowMixedPrecision() - : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} }; TEST_F(HloVerifierTest, NullInstructionParent) { @@ -275,5 +279,84 @@ TEST_F(HloVerifierTest, RngElementTypeNotSupported) { EXPECT_THAT(status.error_message(), HasSubstr("Element type not supported")); } +TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { + // This testcase can't be written using textual HLO, because it doesn't parse + // negative interior padding. That's probably a feature. :) + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {100}), "param")); + PaddingConfig padding_config; + padding_config.add_dimensions()->set_interior_padding(-1); + builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {100}), param, + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(F32).CloneToUnique())), + padding_config)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Interior padding cannot be negative")); +} + +TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { + // This testcase can't be written using textual HLO, because it doesn't parse + // negative interior padding. That's probably a feature. 
:) + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {100}), "param")); + PaddingConfig padding_config; + padding_config.add_dimensions()->set_interior_padding(-1); + builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {100}), param, + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(F32).CloneToUnique())), + padding_config)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("Interior padding cannot be negative")); +} + +// Simple module containing a convolution as the root. +static const char* const kConvHloString = R"( +HloModule module +ENTRY entry_computation { + param0 = f16[128,128,56,56] parameter(0) + param1 = f16[3,3,128,128] parameter(1) + zero_f16 = f16[] constant(0) + ROOT conv = f16[128,128,28,28] convolution(param0, param1), + window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01 +})"; + +TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString)); + auto* conv = module->entry_computation()->root_instruction(); + Window w = conv->window(); + w.mutable_dimensions(0)->set_window_dilation(-1); + conv->set_window(w); + + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("non-positive window dilation factor")); +} + +TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString)); + auto* conv = module->entry_computation()->root_instruction(); + Window w = conv->window(); + w.mutable_dimensions(0)->set_base_dilation(-1); + conv->set_window(w); + + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("non-positive base area dilation factor")); +} + } // namespace } // namespace xla 
diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index bb5b40a8a87c5eab5a5b1599581a81bbd064511b..e76b93107c923b41666f6b0a388dda143a8cb50a 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -14,27 +14,27 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/metric_table_report.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { -using tensorflow::strings::Appendf; +using absl::StrAppend; +using absl::StrAppendFormat; +using absl::StrCat; +using absl::StrFormat; using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; string HumanReadableProfileBuilder::ToString() const { string s; - Appendf(&s, "Execution profile for %s: (%s @ f_nom)\n", - computation_name_.c_str(), - HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); + StrAppendFormat(&s, "Execution profile for %s: (%s @ f_nom)\n", + computation_name_, + HumanReadableElapsedTime(CyclesToSeconds(total_cycles_))); int64 cumulative_cycles = 0; auto print_op = [&](const OpInfo& op, bool is_total = false) { @@ -56,7 +56,7 @@ string HumanReadableProfileBuilder::ToString() const { if (op.bytes_accessed > op.cycles) { bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle"); } else { - bytes_per_cycle = Printf("%.3fB/cycle", bpc); + bytes_per_cycle = 
StrFormat("%.3fB/cycle", bpc); } } @@ -77,27 +77,24 @@ string HumanReadableProfileBuilder::ToString() const { // columns in the output. cycles_percent_str = "100.% 100Σ"; } else { - cycles_percent_str = - Printf("%5.2f%% %2.0fΣ", cycles_percent, cumulative_cycles_percent); + cycles_percent_str = StrFormat("%5.2f%% %2.0fΣ", cycles_percent, + cumulative_cycles_percent); } double nsecs = op.cycles / clock_rate_ghz_; - Appendf( + StrAppendFormat( &s, - "%15lld cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " + "%15d cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " "%16s :: %s\n", - op.cycles, cycles_percent_str.c_str(), CyclesToMicroseconds(op.cycles), + op.cycles, cycles_percent_str, CyclesToMicroseconds(op.cycles), op.optimal_seconds < 0 ? "" - : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), - op.flop_count <= 0 - ? "" - : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), + : StrFormat("(%12.1f optimal)", op.optimal_seconds * 1e6), + op.flop_count <= 0 ? "" : HumanReadableNumFlops(op.flop_count, nsecs), op.transcendental_count <= 0 ? "" - : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) - .c_str(), - bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); + : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs), + bytes_per_sec, bytes_per_cycle, op.name); }; float optimal_seconds_sum = 0.0; diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index 6f56c3aa82e9d1c942fd67ff7a5948cf2e54370d..925111fa1f1e48650b0089f402d92e431043eabe 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -18,8 +18,8 @@ limitations under the License. 
#include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -29,10 +29,10 @@ namespace xla { // computation, suitable for consumption by humans. class HumanReadableProfileBuilder { public: - explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name, + explicit HumanReadableProfileBuilder(absl::string_view computation_name, int64 total_cycles, double clock_rate_ghz) - : computation_name_(std::string(computation_name)), + : computation_name_(computation_name), total_cycles_(total_cycles), clock_rate_ghz_(clock_rate_ghz) { CHECK_GE(clock_rate_ghz, 1e-9); @@ -43,15 +43,13 @@ class HumanReadableProfileBuilder { // Adds an operation to the profile. If you don't know the number of // floating-point ops or bytes touched by the op, or if you don't know how // fast it would run optimally, pass -1 for that param. - void AddOp(tensorflow::StringPiece op_name, - tensorflow::StringPiece short_name, - tensorflow::StringPiece category, int64 cycles, int64 flop_count, + void AddOp(absl::string_view op_name, absl::string_view short_name, + absl::string_view category, int64 cycles, int64 flop_count, int64 transcendental_count, int64 bytes_accessed, float optimal_seconds) { - op_infos_.push_back({std::string(op_name), std::string(short_name), - std::string(category), cycles, flop_count, - transcendental_count, bytes_accessed, - optimal_seconds}); + op_infos_.push_back({string(op_name), string(short_name), string(category), + cycles, flop_count, transcendental_count, + bytes_accessed, optimal_seconds}); } // Gets the human-readable profile. 
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h index aa325dc8a353c5bfbfded0c2774c66bfcc71c9cb..85bb4a8b2450a48d461f1d84e0609a38a6818d9c 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h @@ -30,7 +30,7 @@ class ImplicitBroadcastRemover : public HloPassInterface { ImplicitBroadcastRemover() {} ~ImplicitBroadcastRemover() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "implicit-broadcast-remover"; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 3531b7223fb11df212fa8d30e3adba6aac6c5679..a4de02a89039e07b22b1ad8c268c2f760aa95880 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -14,13 +14,16 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/gtl/optional.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace gtl = ::tensorflow::gtl; @@ -31,32 +34,29 @@ using UnknownArray = Analysis::UnknownArray; using ConstantArray = Analysis::ConstantArray; using ReshapedArray = Analysis::ReshapedArray; using ScalarIndexedArray = Analysis::ScalarIndexedArray; -using tensorflow::gtl::ArraySlice; -using tensorflow::str_util::Join; +using absl::StrJoin; } // namespace string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { switch (root->kind()) { case Array::kUnknown: { auto* unknown_tensor = root->as(); - return tensorflow::strings::StrCat("%", - unknown_tensor->instruction().name()); + return absl::StrCat("%", unknown_tensor->instruction().name()); } case Array::kConstant: { if (print_constants) { string contents = root->as()->literal()->ToString(); - return tensorflow::strings::StrCat( - "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents, - ")"); + return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()), + " ", contents, ")"); } - return tensorflow::strings::StrCat( - "(constant ", ShapeUtil::HumanString(root->shape()), ")"); + return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()), + ")"); } case Array::kReshaped: { ReshapedArray* reshaped_array = root->as(); - return tensorflow::strings::StrCat( + return absl::StrCat( 
"(reshape ", ToString(reshaped_array->operand(), print_constants), " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")"); } @@ -67,11 +67,11 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { string name = root->kind() == Array::kScalarIndexedConstant ? "scalar-indexed-const" : "scalar-indexed"; - return tensorflow::strings::StrCat( + return absl::StrCat( "(", name, " ", ToString(indexed_array->source(), print_constants), " ", ToString(indexed_array->indices(), print_constants), " ", indexed_array->source_dim(), "->[", - Join(indexed_array->output_dims(), ","), "])"); + StrJoin(indexed_array->output_dims(), ","), "])"); } } } @@ -92,7 +92,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( // Depth first search over the DAG, invoking ComputeArrayFor in post order. // The HLO instructions already in the cache are considered leaves. - gtl::InlinedVector stack; + absl::InlinedVector stack; enum DfsState { kDiscovered, kVisited }; gtl::FlatMap dfs_state_map; @@ -153,7 +153,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayFor( TF_ASSIGN_OR_RETURN( computed_array, ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(), - instr->gather_window_bounds(), + instr->gather_slice_sizes(), FindOrDie(cache_, instr->operand(0)), FindOrDie(cache_, instr->operand(1)))); } else if (instr->opcode() == HloOpcode::kReshape) { @@ -185,7 +185,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForConstant( StatusOr IndexedArrayAnalysis::FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, - tensorflow::gtl::ArraySlice output_dims, Shape shape) { + absl::Span output_dims, Shape shape) { // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). // `source` is the inner Gather(A, X). 
@@ -251,24 +251,22 @@ StatusOr IndexedArrayAnalysis::FoldGatherOfGather( StatusOr IndexedArrayAnalysis::ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice window_bounds, Array* source, - Array* indices) { + absl::Span slice_sizes, Array* source, Array* indices) { if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { VLOG(3) << "ComputeArrayForGather: indices are not scalar"; return nullptr; } - CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); + CHECK_EQ(dim_numbers.start_index_map_size(), 1); - // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should - // it become relevant. + // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here, + // should it become relevant. - if (dim_numbers.elided_window_dims_size() != 1 || - dim_numbers.elided_window_dims(0) != - dim_numbers.gather_dims_to_operand_dims(0)) { + if (dim_numbers.collapsed_slice_dims_size() != 1 || + dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) { VLOG(3) << "ComputeArrayForGather: gather operations must elide " - "gather_dims_to_operand_dims[0] and " - "gather_dims_to_operand_dims[0] only"; + "start_index_map[0] and " + "start_index_map[0] only"; return nullptr; } @@ -277,27 +275,27 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForGather( // arrays from an array of size [7,4,6]. We check that condition down below: for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) { - if (i != dim_numbers.elided_window_dims(0) && - source->shape().dimensions(i) != window_bounds[i]) { - VLOG(3) << "ComputeArrayForGather: window_bounds[" << i + if (i != dim_numbers.collapsed_slice_dims(0) && + source->shape().dimensions(i) != slice_sizes[i]) { + VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i << "] != source->shape().dimensions(" << i << ") -- " - << source->shape().dimensions(i) << " vs. 
" << window_bounds[i] - << " with dim_numbers.elided_window_dims(0) = " - << dim_numbers.elided_window_dims(0); + << source->shape().dimensions(i) << " vs. " << slice_sizes[i] + << " with dim_numbers.collapsed_slice_dims(0) = " + << dim_numbers.collapsed_slice_dims(0); return nullptr; } } - int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); + int64 source_dim = dim_numbers.start_index_map(0); std::vector output_dims; for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { output_dims.push_back(i); } } if (auto* indexed = dynamic_cast(source)) { - if (c_linear_search(indexed->output_dims(), source_dim)) { + if (absl::c_linear_search(indexed->output_dims(), source_dim)) { return FoldGatherOfGather(indexed, indices, source_dim, output_dims, shape); } @@ -314,8 +312,8 @@ namespace { // Returns an index into `values` such that the product of the range // [values.begin()+index, values.end()) is equal to `product`. If there is no // such index, return -1. All integers in `values` must be positive. -int64 FindSuffixWithProduct(ArraySlice values, int64 product) { - DCHECK(c_all_of(values, [](int64 value) { return value > 0; })); +int64 FindSuffixWithProduct(absl::Span values, int64 product) { + DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; })); int64 current_product = 1; int64 i; @@ -343,7 +341,8 @@ struct ReshapePassthroughDimPair { // The returned vector of pairs is sorted in both the result_dim and the // operand_dim components. 
std::vector ComputeReshapePassthroughDimPairs( - ArraySlice operand_shape, ArraySlice result_shape) { + absl::Span operand_shape, + absl::Span result_shape) { // A reshape can be seen as an index mapping from output index to input index: // // (i_0, ..., i_n) = f(o_0, ..., o_m) @@ -378,8 +377,8 @@ std::vector ComputeReshapePassthroughDimPairs( CHECK_NE(candidate_operand_dim, 0) << "result_dim = " << result_dim << ", result_subarray_size = " << result_subarray_size - << ", result_shape = [" << Join(result_shape, ",") << "]" - << ", operand_shape = [" << Join(operand_shape, ",") << "]"; + << ", result_shape = [" << StrJoin(result_shape, ",") << "]" + << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]"; if (candidate_operand_dim != -1 && result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { @@ -389,26 +388,27 @@ std::vector ComputeReshapePassthroughDimPairs( result_subarray_size *= result_shape[result_dim]; } - c_reverse(result); + absl::c_reverse(result); if (VLOG_IS_ON(3)) { std::vector result_strings; - c_transform(result, std::back_inserter(result_strings), - [](ReshapePassthroughDimPair value) { - return tensorflow::strings::StrCat(value.result_dim, "->", - value.operand_dim); - }); - VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" - << Join(result_shape, ",") << "] passthrough indices are [" - << Join(result_strings, ",") << "] (legend: `result`->`operand`)"; + absl::c_transform(result, std::back_inserter(result_strings), + [](ReshapePassthroughDimPair value) { + return absl::StrCat(value.result_dim, "->", + value.operand_dim); + }); + VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to [" + << StrJoin(result_shape, ",") << "] passthrough indices are [" + << StrJoin(result_strings, ",") + << "] (legend: `result`->`operand`)"; } - DCHECK(c_is_sorted( + DCHECK(absl::c_is_sorted( result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { return lhs.result_dim < 
rhs.result_dim; })); - DCHECK(c_is_sorted( + DCHECK(absl::c_is_sorted( result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { return lhs.operand_dim < rhs.operand_dim; })); @@ -419,30 +419,31 @@ std::vector ComputeReshapePassthroughDimPairs( // Return true if `dim` is stated as an passthrough operand dim in // `passthrough_dims`. bool IsReshapePassthroughOperandDim( - ArraySlice passthrough_dims, int64 dim) { - return c_any_of(passthrough_dims, - [&](ReshapePassthroughDimPair passthrough_dim_pair) { - return passthrough_dim_pair.operand_dim == dim; - }); + absl::Span passthrough_dims, int64 dim) { + return absl::c_any_of(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == dim; + }); } // Maps `operand_dim` which must be an passthrough operand dimension to its // corresponding passthrough result dimension based on `passthrough_dims`. int64 MapPassthroughOperandDimToResultDim( - ArraySlice passthrough_dims, int64 operand_dim) { - auto it = c_find_if(passthrough_dims, - [&](ReshapePassthroughDimPair passthrough_dim_pair) { - return passthrough_dim_pair.operand_dim == operand_dim; - }); + absl::Span passthrough_dims, + int64 operand_dim) { + auto it = absl::c_find_if( + passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == operand_dim; + }); CHECK(it != passthrough_dims.end()); return it->result_dim; } -int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, - ArraySlice result_shape, - int64 source_passthrough_dim) { +int64 FindSourcePositionForPassthroughResultDim( + absl::Span operand_shape, absl::Span result_shape, + int64 source_passthrough_dim) { VLOG(3) << "FindSourcePositionForPassthroughResultDim([" - << Join(operand_shape, ",") << "], [" << Join(result_shape, ",") + << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",") << "], " << source_passthrough_dim << ")"; int64 
indexed_source_subarray_size = @@ -454,8 +455,8 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, Shape StripDegenerateDimensions(const Shape& shape) { DimensionVector new_dims; - c_copy_if(shape.dimensions(), std::back_inserter(new_dims), - [](int64 dim) { return dim != 1; }); + absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims), + [](int64 dim) { return dim != 1; }); return ShapeUtil::MakeShape(shape.element_type(), new_dims); } }; // namespace @@ -498,7 +499,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { if (shape.dimensions(i) == 1) { degenerate_dims_seen++; - } else if (ArrayContains(operand->output_dims(), i)) { + } else if (absl::c_linear_search(operand->output_dims(), i)) { new_output_dims.push_back(i - degenerate_dims_seen); } } @@ -518,8 +519,7 @@ IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( } StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( - ScalarIndexedArray* operand, - tensorflow::gtl::ArraySlice degenerate_dims) { + ScalarIndexedArray* operand, absl::Span degenerate_dims) { if (degenerate_dims.empty()) { return operand; } @@ -531,7 +531,7 @@ StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( // element is true iff the i'th component of the result index is an output // index. 
- gtl::InlinedVector output_dims_bitvector( + absl::InlinedVector output_dims_bitvector( operand->shape().dimensions_size()); for (int64 output_dim : operand->output_dims()) { output_dims_bitvector[output_dim] = true; @@ -553,8 +553,8 @@ StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( }(); DimensionVector new_result_shape_dims; - c_copy(operand->shape().dimensions(), - std::back_inserter(new_result_shape_dims)); + absl::c_copy(operand->shape().dimensions(), + std::back_inserter(new_result_shape_dims)); for (int64 degenerate_dim : degenerate_dims) { InsertAt(&new_result_shape_dims, degenerate_dim, 1); } @@ -695,8 +695,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( operand_dim); }; - if (!c_all_of(scalar_indexed->output_dims(), - is_reshape_passthrough_operand_dim)) { + if (!absl::c_all_of(scalar_indexed->output_dims(), + is_reshape_passthrough_operand_dim)) { VLOG(3) << "Not all output dims are passthrough dims " << ToString(scalar_indexed); return nullptr; @@ -735,11 +735,11 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( // operand = s32[3,5,2] constant({...}) // indices = s32[7] parameter(0) // gather = s32[3,2,7] gather(operand, indices), - // output_window_dims={0,1}, - // elided_window_dims={1}, - // gather_dims_to_operand_dims={1}, + // offset_dims={0,1}, + // collapsed_slice_dims={1}, + // start_index_map={1}, // index_vector_dim=1, - // window_bounds={3,1,2} + // slice_sizes={3,1,2} // reshape = s32[6,7] reshape(gather) // // In this case the gather maps to: @@ -754,9 +754,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( if (source_dim_for_new_scalar_indexed_node == -1) { VLOG(3) << "Could not compute the source dim for the new scalar indexed " "node: scalar_indexed_source_shape = [" - << Join(scalar_indexed_source_shape.dimensions(), ",") + << StrJoin(scalar_indexed_source_shape.dimensions(), ",") << "] and new_scalar_indexed_source_shape = [" - << Join(new_scalar_indexed_source_shape, ",") << "]"; + 
<< StrJoin(new_scalar_indexed_source_shape, ",") << "]"; return nullptr; } @@ -764,8 +764,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); - CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1LL, - std::multiplies()), + CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL, + std::multiplies()), ShapeUtil::ElementsIn(scalar_indexed_source_shape)); CHECK(IsReshapePassthroughOperandDim( @@ -781,9 +781,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( }; std::vector output_dims_for_new_scalar_indexed_node; - c_transform(scalar_indexed->output_dims(), - std::back_inserter(output_dims_for_new_scalar_indexed_node), - map_passthrough_operand_dim_to_result_dim); + absl::c_transform(scalar_indexed->output_dims(), + std::back_inserter(output_dims_for_new_scalar_indexed_node), + map_passthrough_operand_dim_to_result_dim); TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal, TakeOwnership(scalar_indexed->literal().Reshape( @@ -872,13 +872,14 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, return nullptr; } - ArraySlice broadcast_dims = broadcast_instr->dimensions(); + absl::Span broadcast_dims = broadcast_instr->dimensions(); auto is_broadcasted_dim = [&](int64 output_dim) { - return c_find(broadcast_dims, output_dim) == broadcast_dims.end(); + return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end(); }; // All of the output dims must be "broadcasted" dims for the other operand. 
- if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) { + if (!absl::c_all_of(scalar_indexed_const->output_dims(), + is_broadcasted_dim)) { return nullptr; } @@ -894,7 +895,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, // The scalar-indexed node "removes" the source dim and "inserts" the output // dims. We do the opposite here to undo the scalar-indexed operation. - ArraySlice output_dims = scalar_indexed_const->output_dims(); + absl::Span output_dims = scalar_indexed_const->output_dims(); for (int64 i = output_dims.size() - 1; i >= 0; --i) { CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted); EraseAt(&simulated_index, output_dims[i]); @@ -970,15 +971,15 @@ namespace { // Returns the non-contracting non-batch dimension (as per `contracting_dims` // and `batch_dims`) if there is exactly one, otherwise returns nullopt. -gtl::optional GetOnlyNonContractingNonBatchDim( - int64 rank, ArraySlice contracting_dims, - ArraySlice batch_dims) { - gtl::optional result; +absl::optional GetOnlyNonContractingNonBatchDim( + int64 rank, absl::Span contracting_dims, + absl::Span batch_dims) { + absl::optional result; for (int64 dim = 0; dim < rank; dim++) { - if (!ArrayContains(contracting_dims, dim) && - !ArrayContains(batch_dims, dim)) { + if (!absl::c_linear_search(contracting_dims, dim) && + !absl::c_linear_search(batch_dims, dim)) { if (result.has_value()) { - return gtl::nullopt; + return absl::nullopt; } result = dim; } @@ -995,10 +996,10 @@ gtl::optional GetOnlyNonContractingNonBatchDim( // `contracting_dims` and `batch_dims` are the contracting and batch dimensions // of whatever operand `indexed_array` is to the dot (LHS or RHS). 
bool CanFoldDotIntoIndexedArray( - tensorflow::StringPiece tag, - Analysis::ScalarIndexedConstantArray* indexed_array, - ArraySlice contracting_dims, ArraySlice batch_dims) { - gtl::optional non_contracting_non_batch_dim = + absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array, + absl::Span contracting_dims, + absl::Span batch_dims) { + absl::optional non_contracting_non_batch_dim = GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), contracting_dims, batch_dims); if (!non_contracting_non_batch_dim.has_value()) { @@ -1133,7 +1134,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForDot( return nullptr; } -tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { +absl::string_view IndexedArrayAnalysisPrinterPass::name() const { return "indexed-array-analysis-printer-pass"; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index e923dc39f7f464a8d3c400294499a6f5efda3991..dcfb7255358ae08660fe2c6eae5af9f10370e762 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -188,9 +188,7 @@ class IndexedArrayAnalysis { // `output_dims` are the dimensions in the output array that are being used // to compute an index into the `indices` array. See the class // documentation and the overview for more details. 
- tensorflow::gtl::ArraySlice output_dims() const { - return output_dims_; - } + absl::Span output_dims() const { return output_dims_; } private: explicit ScalarIndexedArray(Array* source, Array* indices, int64 source_dim, @@ -265,8 +263,7 @@ class IndexedArrayAnalysis { StatusOr ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice window_bounds, Array* source, - Array* indices); + absl::Span slice_sizes, Array* source, Array* indices); StatusOr ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, @@ -303,7 +300,7 @@ class IndexedArrayAnalysis { // G1 = [Arr[i] for i in I2] StatusOr FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, - tensorflow::gtl::ArraySlice output_dims, Shape shape); + absl::Span output_dims, Shape shape); // Reshapes a scalar-indexed node to remove the degenerate dimensions in its // output. The result is always a scalar-indexed node. @@ -313,8 +310,7 @@ class IndexedArrayAnalysis { // Reshapes a scalar-indexed node such that the result has the degenerate // dimensions `degenerate_dims`. The result is always a scalar-indexed node. StatusOr ReshapeToAddDegenerateDims( - ScalarIndexedArray* operand, - tensorflow::gtl::ArraySlice degenerate_dims); + ScalarIndexedArray* operand, absl::Span degenerate_dims); StatusOr FoldReshapeOfGather( const Shape& shape, ScalarIndexedConstantArray* operand); @@ -371,7 +367,7 @@ class IndexedArrayAnalysis { // unconditionally add to the regular HLO pass pipeline. 
class IndexedArrayAnalysisPrinterPass : public HloPassInterface { public: - tensorflow::StringPiece name() const override; + absl::string_view name() const override; StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 5f4b42799b1c26ea544f9d4447cc45b5ae9d5a48..2d03aebc1aca4c55cca588072233b7a18e70a306 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -82,11 +82,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} } )"; @@ -102,11 +102,11 @@ ENTRY main { operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5] parameter(0) ROOT gather = s32[5,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} } )"; @@ -122,11 +122,11 @@ ENTRY main { operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) indices = s32[5,2] parameter(0) ROOT gather = s32[5] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} } )"; @@ -141,11 +141,11 @@ ENTRY main { operand = s32[3,3,1] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,2}, - 
gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0,2}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3,1} + slice_sizes={1,3,1} } )"; @@ -160,11 +160,11 @@ ENTRY main { operand = s32[3,3,1] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,2,3] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={2}, - gather_dims_to_operand_dims={0}, + offset_dims={1,2}, + collapsed_slice_dims={2}, + start_index_map={0}, index_vector_dim=1, - window_bounds={2,3,1} + slice_sizes={2,3,1} } )"; @@ -179,11 +179,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[5] parameter(1) ROOT gather = s32[5,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,2} + slice_sizes={1,2} } )"; @@ -199,17 +199,17 @@ ENTRY main { indices_a = s32[5] parameter(0) indices_b = s32[2] parameter(1) gather_a = s32[5,3] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} ROOT gather_b = s32[2,3] gather(gather_a, indices_b), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} } )"; @@ -228,17 +228,17 @@ ENTRY main { indices_a = s32[5,7] parameter(1) indices_b = s32[2] parameter(2) gather_a = s32[5,3,7] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3,1} + slice_sizes={3,1} ROOT gather_b = s32[5,3,2] gather(gather_a, 
indices_b), - output_window_dims={0,1}, - elided_window_dims={2}, - gather_dims_to_operand_dims={2}, + offset_dims={0,1}, + collapsed_slice_dims={2}, + start_index_map={2}, index_vector_dim=1, - window_bounds={5,3,1} + slice_sizes={5,3,1} } )"; @@ -256,17 +256,17 @@ ENTRY main { indices_a = s32[2] parameter(1) indices_b = s32[5,7] parameter(2) gather_a = s32[2,6] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,6} + slice_sizes={1,6} ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,6} + slice_sizes={1,6} } )"; @@ -284,17 +284,17 @@ ENTRY main { indices_a = s32[5,7] parameter(1) indices_b = s32[4,8] parameter(2) gather_a = s32[5,3,7] gather(operand, indices_a), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3,1} + slice_sizes={3,1} ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b), - output_window_dims={1,2}, - elided_window_dims={2}, - gather_dims_to_operand_dims={2}, + offset_dims={1,2}, + collapsed_slice_dims={2}, + start_index_map={2}, index_vector_dim=2, - window_bounds={5,3,1} + slice_sizes={5,3,1} } )"; @@ -312,11 +312,11 @@ ENTRY main { operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5] parameter(0) gather = s32[5,4] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT reshape = s32[5,2,2] reshape(gather) } 
)"; @@ -333,11 +333,11 @@ ENTRY main { operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,7] parameter(0) gather = s32[5,4,7] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,4} + slice_sizes={1,4} ROOT reshape = s32[5,2,2,7] reshape(gather) } )"; @@ -358,11 +358,11 @@ ENTRY main { {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[5,7] parameter(0) gather = s32[5,2,6,7] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1,2}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,2,6} + slice_sizes={1,2,6} ROOT reshape = s32[5,3,4,7] reshape(gather) } )"; @@ -381,11 +381,11 @@ ENTRY main { {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,6} + slice_sizes={1,6} ROOT reshape = s32[1,1,6] reshape(gather) } )"; @@ -408,14 +408,14 @@ ENTRY main { operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) i.0 = s64[1,3]{1,0} parameter(0) - g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2}, - elided_window_dims={0}, gather_dims_to_operand_dims={0}, - index_vector_dim=2, window_bounds={1,3} + g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), offset_dims={2}, + collapsed_slice_dims={0}, start_index_map={0}, + index_vector_dim=2, slice_sizes={1,3} i.1 = s64[1] parameter(1) - g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2}, - elided_window_dims={1}, gather_dims_to_operand_dims={1}, - index_vector_dim=1, window_bounds={1,1,3} + g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), 
offset_dims={0,2}, + collapsed_slice_dims={1}, start_index_map={1}, + index_vector_dim=1, slice_sizes={1,1,3} ROOT reshape = s32[1,3]{1,0} reshape(g.1) } @@ -441,11 +441,11 @@ ENTRY main { operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) indices = s32[1] parameter(0) gather = s32[1,6] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,6} + slice_sizes={1,6} ROOT reshape = s32[1,1,6] reshape(gather) } )"; @@ -469,11 +469,11 @@ ENTRY main { {1,2,3,4,5,6},{1,2,3,4,5,6}}}) indices = s32[1] parameter(0) gather = s32[1,1,6] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1,2}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={1,1,6} + slice_sizes={1,1,6} ROOT reshape = s32[1,1,1,6] reshape(gather) } )"; @@ -500,11 +500,11 @@ ENTRY main { {1,2,3,4,5,6},{1,2,3,4,5,6}}) indices = s32[1,5] parameter(0) gather = s32[1,5,6] gather(operand, indices), - output_window_dims={2}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={2}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,6} + slice_sizes={1,6} ROOT reshape = s32[1,1,5,6] reshape(gather) } )"; @@ -530,11 +530,11 @@ ENTRY main { operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,4} + slice_sizes={1,4} ROOT reshape = s32[5,2,2,2,3] reshape(gather) } )"; @@ -562,11 +562,11 @@ ENTRY main { {{1,2},{3,4},{5,6},{7,8},{9,10}}}) indices = s32[7] parameter(0) gather = s32[3,2,7] 
gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0,1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1,2} + slice_sizes={3,1,2} ROOT reshape = s32[6,7] reshape(gather) } )"; @@ -594,11 +594,11 @@ ENTRY main { {{1},{2},{3},{4}}}) indices = s32[5,6] parameter(0) gather = s32[5,4,6,1] gather(operand, indices), - output_window_dims={1,3}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1,3}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=2, - window_bounds={1,4,1} + slice_sizes={1,4,1} ROOT reshape = s32[5,2,2,2,3,1] reshape(gather) } )"; @@ -623,20 +623,20 @@ ENTRY main { operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) indices = s32[5] parameter(0) gather = f32[5,4] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT tanh = f32[5,4] tanh(gather) } )"; AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( (scalar-indexed-const (constant f32[3,4] f32[3,4] { - { 0.761594176, 0.964027584, 0.995054781, 0.999329329 }, - { 0.761594176, 0.995054781, 0.964027584, 0.999329329 }, - { 0.999329329, 0.995054781, 0.964027584, 0.761594176 } + { 0.761594, 0.964028, 0.995055, 0.999329 }, + { 0.761594, 0.995055, 0.964028, 0.999329 }, + { 0.999329, 0.995055, 0.964028, 0.761594 } }) %indices 0->[0]))"); } @@ -650,11 +650,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + 
slice_sizes={1,4} ROOT add = s32[5,4] add(gather, constant_broadcasted) } )"; @@ -678,11 +678,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT sub = s32[5,4] subtract(gather, constant_broadcasted) } )"; @@ -706,11 +706,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT sub = s32[5,4] subtract(constant_broadcasted, gather) } )"; @@ -733,11 +733,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT add = s32[5,4] add(gather, constant_broadcasted) } )"; @@ -760,11 +760,11 @@ ENTRY main { constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} indices = s32[5] parameter(0) gather = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT add = s32[5,4] add(gather, constant_broadcasted) } )"; @@ -808,11 +808,11 @@ ENTRY main { dot_rhs_constant 
= s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_lhs = s32[5,4] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,4} + slice_sizes={1,4} ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; @@ -835,11 +835,11 @@ ENTRY main { dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}}) indices = s32[5] parameter(0) dot_lhs = s32[3,5] gather(gather_operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1} + slice_sizes={3,1} ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0} } )"; @@ -863,11 +863,11 @@ ENTRY main { dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[3,5] gather(gather_operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1} + slice_sizes={3,1} ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; @@ -892,11 +892,11 @@ ENTRY main { dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}}) indices = s32[5] parameter(0) dot_rhs = s32[5,3] gather(gather_operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1,3} + slice_sizes={1,3} ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), 
lhs_contracting_dims={1}, rhs_contracting_dims={1} } )"; @@ -921,11 +921,11 @@ ENTRY main { dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}}) indices = s32[4] parameter(0) dot_rhs = s32[2,3,4] gather(gather_operand, indices), - output_window_dims={0,1}, - elided_window_dims={2}, - gather_dims_to_operand_dims={2}, + offset_dims={0,1}, + collapsed_slice_dims={2}, + start_index_map={2}, index_vector_dim=1, - window_bounds={2,3,1} + slice_sizes={2,3,1} ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} @@ -952,11 +952,11 @@ ENTRY main { dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}}) indices = s32[2] parameter(0) dot_lhs = s32[3,2] gather(gather_operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3,1} + slice_sizes={3,1} ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0} } )"; diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc index 5c193fceb984448cf0532d7e1010281268614293..5fd779ebf9b59e34a0844cc3a898bb72ce6044ee 100644 --- a/tensorflow/compiler/xla/service/inliner.cc +++ b/tensorflow/compiler/xla/service/inliner.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -27,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h index a523811f6c141a7dc24b1c88897d82d046aa1a2d..efa8ed3abcc6cd7cd8d31ec2170eae8752988c09 100644 --- a/tensorflow/compiler/xla/service/inliner.h +++ b/tensorflow/compiler/xla/service/inliner.h @@ -27,7 +27,7 @@ namespace xla { class Inliner : public HloPassInterface { public: ~Inliner() override = default; - tensorflow::StringPiece name() const override { return "inline"; } + absl::string_view name() const override { return "inline"; } // Run inlining on the given computation. Returns whether the computation was // changed. diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 32937b33b3737482f07d4c7607f7f1c5c183a56b..5695bc242057c037a1999e7d63f5b4f21b5f658a 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index f33942d67907d8f40811bde5041350a2e1e1f1fc..8c907eae0cbe7c3764a2bfe8fed6b6098931de38 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" @@ -121,6 +122,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDomain: @@ -130,7 +132,6 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kGather: - case HloOpcode::kHostCompute: case HloOpcode::kLog: case HloOpcode::kLog1p: case HloOpcode::kMap: @@ -171,7 +172,8 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { }); return std::count_if(hlo->operands().begin(), hlo->operands().end(), [output_rank](HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kBroadcast) { + if (operand->opcode() == HloOpcode::kBroadcast || + operand->opcode() == HloOpcode::kIota) { return false; } if (operand->opcode() == HloOpcode::kConstant && @@ -189,13 +191,13 @@ bool InstructionFusion::CanFuseOnAllPaths( if (consumer == producer) { return true; } - if (!consumer->IsFusable()) { + if (!consumer->IsFusible()) { return false; } for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { auto* consumer_operand = consumer->mutable_operand(i); // If the operand is not on a path to the producer, it doesn't matter - // whether it's fusable. + // whether it's fusible. if (!reachability_->IsReachable(producer, consumer_operand)) { continue; } @@ -205,7 +207,7 @@ bool InstructionFusion::CanFuseOnAllPaths( } // The producer is reachable from consumer_operand which means we need // to be able to fuse consumer_operand into consumer in order for - // producer to be fusable into consumer on all paths. + // producer to be fusible into consumer on all paths. 
// Perform the recursive step: make sure producer can be fused into // consumer_operand on all paths. if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { @@ -216,8 +218,8 @@ bool InstructionFusion::CanFuseOnAllPaths( } InstructionFusion::HloInstructionSet -InstructionFusion::ComputeGloballyUnfusable( - tensorflow::gtl::ArraySlice post_order) { +InstructionFusion::ComputeGloballyUnfusible( + absl::Span post_order) { // Forbid fusion of producers that: // a) Need to be duplicated, unless they can be fused into all consumers // via all paths. @@ -270,19 +272,19 @@ InstructionFusion::ComputeGloballyUnfusable( // all of its consumers on all paths. // // That means, that for: - // A --> B (fusable) - // \-> C (non-fusable) + // A --> B (fusible) + // \-> C (non-fusible) // A will be not allowed to be fused into B, as it cannot be fused into C. // // Similarly, for: // A -------------> B // \-> C -> D -/ // If: - // - A is fusable into B and C, and D is fusable into B - // - C is *not* fusable into D + // - A is fusible into B and C, and D is fusible into B + // - C is *not* fusible into D // A will be not allowed to be fused into B, as it cannot be fused via // all paths. - if (producer->IsFusable() && + if (producer->IsFusible() && CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { continue; } @@ -318,7 +320,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { InsertOrDie(&post_order_index, post_order[i], i); } - HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order); + HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all @@ -341,7 +343,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { // consistent. 
post_order_index.erase(instruction); - if (!instruction->IsFusable() && + if (!instruction->IsFusible() && instruction->opcode() != HloOpcode::kFusion) { continue; } @@ -413,7 +415,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); - if (!operand->IsFusable()) { + if (!operand->IsFusible()) { continue; } @@ -497,7 +499,7 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { - return c_any_of( + return absl::c_any_of( consumer->operands(), [&](const HloInstruction* consumer_operand) { // The fusion algorithm traverses the HLO graph in reverse post order. // Thus `cosumers` is visited before its operands (including diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index f73ca9adf768ed26f9ec9f162e01b7b160f50daf..00b658959a2cceeb30d2ec03f243119ec0a8ee47 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -36,7 +36,7 @@ class InstructionFusion : public HloPassInterface { bool may_duplicate = true) : is_expensive_(is_expensive), may_duplicate_(may_duplicate) {} ~InstructionFusion() override = default; - tensorflow::StringPiece name() const override { return "fusion"; } + absl::string_view name() const override { return "fusion"; } // Run instruction fusion on the given computation. Returns whether the // computation was changed (instructions were fused). @@ -122,8 +122,8 @@ class InstructionFusion : public HloPassInterface { // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. 
- HloInstructionSet ComputeGloballyUnfusable( - tensorflow::gtl::ArraySlice post_order); + HloInstructionSet ComputeGloballyUnfusible( + absl::Span post_order); // Used to determine if an HLO is expensive. Expensive operations will not be // duplicated. diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 9e7a15f0330d3f06779c850a4b575f84fe0b9505..da1ad90959dc0ab1a840b3390281ce9d4999651e 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -158,7 +158,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { .ValueOrDie()); } -TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) { HloComputation::Builder builder(TestName()); auto shape = ShapeUtil::MakeShape(F32, {16, 16}); auto param0 = @@ -216,7 +216,7 @@ TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); } -TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { // Make sure we do not duplicate the add, as we cannot fuse through the rng. // // p0 -> add -------------------------> sub @@ -309,7 +309,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); // A variant of the above that allows the algorithm to put add2 into the set - // of unfusable ops to short-circuit the decision whether add1 should be fused + // of unfusible ops to short-circuit the decision whether add1 should be fused // into sub2. 
// // /---------------\ diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 8652599dc6d48ff8c2aaa703fead161f891a57d1..146c9052f10cca8b199a480491d9a672d8bebdff 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -12,12 +12,11 @@ cc_library( srcs = ["interpreter_transfer_manager.cc"], hdrs = ["interpreter_transfer_manager.h"], deps = [ - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform_id", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -32,8 +31,6 @@ cc_library( "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", @@ -54,6 +51,7 @@ cc_library( "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains compiler registration ) @@ -79,7 +77,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", @@ -91,6 +88,8 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + 
"@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -116,5 +115,6 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 9f8f4bda875cdff5e20fa8ca8eeecaa1140e2b9c..bb69cb9c47ff2c7de8d13832c4b8e6216c62da73 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -69,8 +69,8 @@ StatusOr> InterpreterCompiler::RunBackend( // Create executable from only the Hlo module. 
std::unique_ptr executable = - xla::MakeUnique(std::move(hlo_module), - xla::MakeUnique()); + absl::make_unique( + std::move(hlo_module), absl::make_unique()); return std::move(executable); } @@ -103,11 +103,11 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() static bool InitModule() { xla::Compiler::RegisterCompilerFactory( se::interpreter::kXlaInterpreterPlatformId, []() { - return xla::MakeUnique(); + return absl::make_unique(); }); xla::ComputationPlacer::RegisterComputationPlacer( se::interpreter::kXlaInterpreterPlatformId, - []() { return xla::MakeUnique(); }); + []() { return absl::make_unique(); }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 8d40c08d555a232b7cf3b81cc0f9970804c2f896..5dea12476849db6f7a9a9214398b4e57262aeda0 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" @@ -47,7 +47,7 @@ InterpreterExecutable::~InterpreterExecutable() {} StatusOr InterpreterExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); se::StreamExecutor* executor = stream->parent(); @@ -111,7 +111,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( StatusOr InterpreterExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { return tensorflow::errors::Unimplemented( "ExecuteAsyncOnStream is not yet supported on Interpreter."); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 91d8148d26dc8eddbafdaf4870d9efbb73a12816..3b1ebce0c75457d65e6834c809fe488a9c4a159a 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -29,7 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -48,13 +48,13 @@ class InterpreterExecutable : public Executable { StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override LOCKS_EXCLUDED(evaluator_lock_); StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; + absl::Span arguments) override; static int64 ShapeSizeBytes(const Shape& shape); diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index db6b910b32f8ec234c4cf1c331a1aa3bb2f9389f..fbb99457847dca69a1901006d5d8ff713882f918 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -22,9 +22,9 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/device_description.h" @@ -47,7 +47,7 @@ limitations under the License. 
namespace stream_executor { namespace interpreter { -using Args = tensorflow::gtl::ArraySlice; +using Args = absl::Span; class XlaInterpreterExecutor : public internal::StreamExecutorInterface { public: diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc index d27cd7502f10a1f615fc5b0d610acafdf55e3e43..7955ee5cf37f3fa45b942d8ab05a60076857dc6c 100644 --- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -31,7 +31,7 @@ InterpreterTransferManager::InterpreterTransferManager() static std::unique_ptr CreateInterpreterTransferManager() { - return xla::MakeUnique(); + return absl::make_unique(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h index 2b44f308218e2f61f08012769246b8a0e9639822..b732230fdd88b694f21ad5bc03d373331f8fb8f9 100644 --- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h +++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_ #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/core/platform/macros.h" @@ -33,4 +33,4 @@ class InterpreterTransferManager : public GenericTransferManager { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index 42c2c28997d5f3b02f1fe4effca164c893e4071d..c9b40d3c6195f80a19272a0d98890049d02315b9 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -17,13 +17,14 @@ limitations under the License. 
#include +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/ptr_util.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status_macros.h" -#include "tensorflow/stream_executor/lib/stringprintf.h" #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" @@ -70,15 +71,15 @@ port::StatusOr XlaInterpreterPlatform::GetExecutor( port::StatusOr> XlaInterpreterPlatform::GetUncachedExecutor( const StreamExecutorConfig& config) { - auto executor = MakeUnique( - this, MakeUnique(config.plugin_config)); + auto executor = absl::make_unique( + this, absl::make_unique(config.plugin_config)); auto init_status = executor->Init(config.ordinal, config.device_options); if (!init_status.ok()) { return port::Status{ port::error::INTERNAL, - port::Printf( + absl::StrFormat( "failed initializing StreamExecutor for device ordinal %d: %s", - config.ordinal, init_status.ToString().c_str())}; + config.ordinal, init_status.ToString())}; } return std::move(executor); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 805fdb2d5bd8a08490b354d60f281c8f99bc20d8..6e17711f575b24ffcfcbf1a78bb803603b001adf 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -26,9 +26,13 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -48,21 +52,11 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { -// For now moving only one API here, but we should have a single top level -// anonymous namespace, instead of three or four spread all over this file. 
-namespace { - -} // namespace - std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint) { out << constraint.ToString(); @@ -77,9 +71,8 @@ BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout, } string BufferLayoutConstraint::ToString() const { - return tensorflow::strings::Printf("BufferLayoutConstraint %s: %s", - buffer_->ToString().c_str(), - LayoutUtil::HumanString(layout_).c_str()); + return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(), + LayoutUtil::HumanString(layout_)); } OperandLayoutConstraint::OperandLayoutConstraint( @@ -98,15 +91,14 @@ OperandLayoutConstraint::OperandLayoutConstraint( } string OperandLayoutConstraint::ToString() const { - return tensorflow::strings::Printf( - "OperandLayoutConstraint %s, operand %lld: %s", - instruction_->name().c_str(), operand_no_, - shape_layout_.ToString().c_str()); + return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s", + instruction_->name(), operand_no_, + shape_layout_.ToString()); } string ResultLayoutConstraint::ToString() const { - return tensorflow::strings::Printf("ResultLayoutConstraint: %s", - shape_layout_.ToString().c_str()); + return absl::StrFormat("ResultLayoutConstraint: %s", + shape_layout_.ToString()); } LayoutConstraints::LayoutConstraints( @@ -137,7 +129,7 @@ PointsToSet::BufferSet* LayoutConstraints::GetBufferSet( } auto& buffer_set = buffer_sets_cache_ - .emplace(instruction, MakeUnique()) + .emplace(instruction, absl::make_unique()) .first->second; const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction); points_to_set.ForEachElement( @@ -174,8 +166,7 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, return FailedPrecondition( "Layout of buffer %s cannot be constrained because buffer is not " "array-shaped, has shape: %s", - buffer.ToString().c_str(), - ShapeUtil::HumanString(buffer.shape()).c_str()); + buffer.ToString(), ShapeUtil::HumanString(buffer.shape())); } 
TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); @@ -191,9 +182,8 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, return FailedPrecondition( "Buffer %s already has the layout constraint %s, cannot add " "incompatible constraint %s", - buffer.ToString().c_str(), - LayoutUtil::HumanString(curr_constraint.layout()).c_str(), - LayoutUtil::HumanString(layout).c_str()); + buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()), + LayoutUtil::HumanString(layout)); } iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); } else { @@ -227,11 +217,11 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, } if (curr_shape_layout->mandatory()) { return FailedPrecondition( - "Operand %lld of instruction %s already has a layout constraint " + "Operand %d of instruction %s already has a layout constraint " "%s, cannot add incompatible constraint %s", - operand_no, instruction->name().c_str(), - curr_shape_layout->shape_layout().ToString().c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + operand_no, instruction->name(), + curr_shape_layout->shape_layout().ToString(), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } } @@ -240,9 +230,9 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, // layouts beyond this immediate use and is complicated to handle. 
if (OperandBufferForwarded(instruction, operand_no)) { return FailedPrecondition( - "Cannot constraint layout of operand %lld of instruction %s " + "Cannot constraint layout of operand %d of instruction %s " "because instruction forwards operand's LogicalBuffer(s)", - operand_no, instruction->name().c_str()); + operand_no, instruction->name()); } auto key = std::make_pair(instruction, operand_no); @@ -284,8 +274,8 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout, return FailedPrecondition( "Result of computation %s already has the layout constraint %s, " "cannot add incompatible constraint %s", - computation_->name().c_str(), curr_shape_layout->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + computation_->name(), curr_shape_layout->ToString(), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } // New constraint matches existing constraint. Nothing to do. return Status::OK(); @@ -307,9 +297,8 @@ Status LayoutConstraints::SetInstructionLayout( if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) { return FailedPrecondition( "Instruction %s of shape %s cannot be assigned incompatible layout %s", - instruction->name().c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + instruction->name(), ShapeUtil::HumanString(instruction->shape()), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } // Create a BufferLayoutConstraint for each array shape in the output of the @@ -368,31 +357,27 @@ const ShapeLayout* LayoutConstraints::ResultLayout() const { string LayoutConstraints::ToString() const { string output; - tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ", - computation_->name(), ":\n"); + absl::StrAppend(&output, "LayoutConstraints for computation ", + computation_->name(), ":\n"); for (auto* instruction : computation_->MakeInstructionPostOrder()) { - 
tensorflow::strings::StrAppend(&output, " ", instruction->ToShortString(), - "\n"); + absl::StrAppend(&output, " ", instruction->ToShortString(), "\n"); for (int64 i = 0; i < instruction->operand_count(); ++i) { if (OperandLayout(instruction, i) != nullptr) { - tensorflow::strings::StrAppend( - &output, " operand (", i, - "): ", OperandLayout(instruction, i)->ToString(), "\n"); + absl::StrAppend(&output, " operand (", i, + "): ", OperandLayout(instruction, i)->ToString(), "\n"); } } for (const LogicalBuffer* buffer : points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { if (BufferLayout(*buffer) != nullptr) { - tensorflow::strings::StrAppend( - &output, " ", buffer->ToString(), " : ", - LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n"); + absl::StrAppend(&output, " ", buffer->ToString(), " : ", + LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n"); } } } if (ResultLayout() != nullptr) { - tensorflow::strings::StrAppend(&output, " => ", ResultLayout()->ToString(), - "\n"); + absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n"); } return output; } @@ -763,7 +748,7 @@ Status CheckParameterLayout(HloInstruction* parameter, return InternalError( "parameter instruction %s does not match layout of computation " "shape: %s", - parameter->ToString().c_str(), parameter_layout.ToString().c_str()); + parameter->ToString(), parameter_layout.ToString()); } return Status::OK(); } @@ -774,8 +759,8 @@ Status CheckConstantLayout(HloInstruction* constant) { constant->shape())) { return InternalError( "constant instruction %s does not match the layout of its literal %s", - constant->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(constant->literal().shape()).c_str()); + constant->ToString(), + ShapeUtil::HumanStringWithLayout(constant->literal().shape())); } return Status::OK(); } @@ -908,13 +893,10 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { return InternalError( "Layout of instruction %s at index {%s} does not match " 
"source LogicalBuffer %s: %s vs %s", - instruction->name().c_str(), - tensorflow::str_util::Join(index, ",").c_str(), - buffer->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction_subshape) - .c_str(), - ShapeUtil::HumanStringWithLayout(buffer->shape()) - .c_str()); + instruction->name(), absl::StrJoin(index, ","), + buffer->ToString(), + ShapeUtil::HumanStringWithLayout(instruction_subshape), + ShapeUtil::HumanStringWithLayout(buffer->shape())); } } } @@ -998,17 +980,18 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( CHECK(ShapeUtil::IsArray(instruction->shape())); CHECK(ShapeUtil::IsArray(operand->shape())); - if (instruction->IsElementwiseOnOperand(operand_no) && - !ShapeUtil::IsScalar(operand->shape()) && + if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == - ShapeUtil::Rank(instruction->shape())) { - // Assign operands the same layout as the instruction, so that + ShapeUtil::Rank(instruction->shape()) && + InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) { + // Propagate the result layout to the operand layout if the instruction + // requires the same layout out for the result and the operand. + // + // For elementwise operations, using the same layout for the operands and + // the result also has the following benefits: // 1) the elementwise operation can reuse its operand's buffer, and // 2) the input and output elements can reuse the same linear index. - // - // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit - // from assigning the same layout to input and output. 
- return MakeUnique(output_layout); + return absl::make_unique(output_layout); } if (instruction->opcode() == HloOpcode::kReshape) { @@ -1031,13 +1014,13 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( *operand_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(operand_shape); if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { - return MakeUnique(operand_shape.layout()); + return absl::make_unique(operand_shape.layout()); } if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) { *operand_shape.mutable_layout() = output_layout; if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { - return MakeUnique(output_layout); + return absl::make_unique(output_layout); } } auto aligned_operand_shape = @@ -1046,7 +1029,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( auto operand_layout = aligned_operand_shape.value().layout(); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape)); - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } } @@ -1062,7 +1045,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } return nullptr; @@ -1076,11 +1059,11 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( CHECK(ShapeUtil::IsArray(user->shape()) && ShapeUtil::IsArray(operand->shape())); - if (user->IsElementwiseOnOperand(operand_no) && - !ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) { + if (!ShapeUtil::IsScalar(operand->shape()) && + ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) && + InstructionRequiresInputLayoutEqualToOutputLayout(user)) { // Assign 
users the same layout as the operand. - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } if (user->opcode() == HloOpcode::kReshape) { @@ -1103,13 +1086,13 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( *output_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(output_shape); if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { - return MakeUnique(output_shape.layout()); + return absl::make_unique(output_shape.layout()); } if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) { *output_shape.mutable_layout() = operand_layout; if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { - return MakeUnique(operand_layout); + return absl::make_unique(operand_layout); } } auto aligned_user_shape = @@ -1118,7 +1101,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( auto user_layout = aligned_user_shape.value().layout(); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(user_layout, output_shape)); - return MakeUnique(user_layout); + return absl::make_unique(user_layout); } } @@ -1134,7 +1117,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( } Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape())); - return MakeUnique(user_layout); + return absl::make_unique(user_layout); } return nullptr; @@ -1385,7 +1368,7 @@ StatusOr InferArrayLayout( // This should not happen because we've assigned layouts to all // instructions preceding this one. 
return InternalError("LogicalBuffer %s does not have a layout", - source_buffer->ToString().c_str()); + source_buffer->ToString()); } if (first_buffer_layout == nullptr) { @@ -1400,9 +1383,8 @@ StatusOr InferArrayLayout( return FailedPrecondition( "Array at index {%s} in instruction %s aliases buffers %s " "and %s which have different layouts", - tensorflow::str_util::Join(index, ",").c_str(), - instruction->name().c_str(), source_buffers[0]->ToString().c_str(), - source_buffer->ToString().c_str()); + absl::StrJoin(index, ","), instruction->name(), + source_buffers[0]->ToString(), source_buffer->ToString()); } } @@ -1570,7 +1552,7 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // present in the IR before layout assignment is a bug. return InternalError( "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString().c_str()); + instruction->ToString()); } if (instruction->opcode() != HloOpcode::kInfeed) { LayoutUtil::ClearLayout(instruction->mutable_shape()); @@ -1822,6 +1804,107 @@ StatusOr LayoutAssignment::Run(HloModule* module) { return true; } +bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( + const HloInstruction* instruction) { + switch (instruction->opcode()) { + case HloOpcode::kAbs: + case HloOpcode::kAdd: + case HloOpcode::kAnd: + case HloOpcode::kAtan2: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCeil: + case HloOpcode::kClamp: + case HloOpcode::kClz: + case HloOpcode::kComplex: + case HloOpcode::kConcatenate: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCos: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: + case HloOpcode::kCustomCall: + case HloOpcode::kDivide: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kEq: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFft: + case HloOpcode::kFloor: + case 
HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLe: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kLt: + case HloOpcode::kMap: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kNegate: + case HloOpcode::kNot: + case HloOpcode::kOr: + case HloOpcode::kXor: + case HloOpcode::kPad: + case HloOpcode::kPower: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kReduceWindow: + case HloOpcode::kRemainder: + case HloOpcode::kReverse: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSelect: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kSubtract: + case HloOpcode::kTanh: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + return true; + case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kCall: + case HloOpcode::kConstant: + case HloOpcode::kConvolution: + case HloOpcode::kCopy: + case HloOpcode::kDomain: + case HloOpcode::kDot: + case HloOpcode::kFusion: + case HloOpcode::kGather: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kIota: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReduce: + case HloOpcode::kReshape: + case HloOpcode::kRng: + case HloOpcode::kScatter: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kAfterAll: + case HloOpcode::kTrace: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + return false; + } +} + Status LayoutAssignment::Init() { 
computation_layouts_.clear(); *entry_computation_layout_ = saved_entry_computation_layout_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index f9e8dbea2f8aa224318adf3cf4b5e493792d3093..cf545031d3c7c66770ea4a2392a2df3b8c24cd38 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -297,12 +297,17 @@ class LayoutAssignment : public HloPassInterface { ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} - tensorflow::StringPiece name() const override { return "layout-assignment"; } + absl::string_view name() const override { return "layout-assignment"; } // Assign layouts to the given module. Returns whether the module was changed // (any layouts were changed). StatusOr Run(HloModule* module) override; + // Returns true if the instruction requires that operands with the same rank + // as the output have to have the same layout as the output. + virtual bool InstructionRequiresInputLayoutEqualToOutputLayout( + const HloInstruction* instruction); + protected: // These methods, invoked by PropagateConstraints, propagate a layout // constraint to its neighbors (i.e. operands and users) in order to minimize diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index a16fa75e3032cfa4257d9b5608dd176fdb4ddbdb..021fe630ff6329c51e297d0bb2bee8269a42904b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -40,7 +41,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -59,7 +59,7 @@ class LayoutAssignmentTest : public HloTestBase { EXPECT_IS_OK(layout_assignment.Run(module).status()); } - std::vector LayoutOf(HloModule* module, tensorflow::StringPiece name) { + std::vector LayoutOf(HloModule* module, absl::string_view name) { auto minor_to_major = FindInstruction(module, name)->shape().layout().minor_to_major(); return std::vector(minor_to_major.begin(), minor_to_major.end()); @@ -861,5 +861,115 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } +TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopySliceOperandToAvoidImplicitLayoutChange + + ENTRY CopySliceOperandToAvoidImplicitLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[4,5]{0,1} parameter(1) + slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]} + ROOT add0 = f32[3,4]{1,0} add(par0,slice0) + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + auto copy = FindInstruction(module.get(), "copy.1"); + auto slice = FindInstruction(module.get(), "slice0"); + EXPECT_EQ(slice->operand(0), copy); + EXPECT_TRUE( + LayoutUtil::Equal(slice->shape().layout(), copy->shape().layout())); +} + +TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { 
+ const char* module_str = R"( + HloModule CopyDSliceOperandToAvoidImplicitLayoutChange + + ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[4,5]{0,1} parameter(1) + par2 = s32[2] parameter(2) + dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4} + ROOT add0 = f32[3,4]{1,0} add(par0,dslice0) + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + auto copy = FindInstruction(module.get(), "copy.1"); + auto dslice = FindInstruction(module.get(), "dslice0"); + EXPECT_EQ(dslice->operand(0), copy); + EXPECT_TRUE( + LayoutUtil::Equal(dslice->shape().layout(), copy->shape().layout())); +} + +TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopyConcatOperandToAvoidImplicitLayoutChange + + ENTRY CopyConcatOperandToAvoidImplicitLayoutChange { + par0 = f32[3,8]{1,0} parameter(0) + par1 = f32[3,5]{0,1} parameter(1) + par2 = f32[3,3]{1,0} parameter(2) + concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2), + dimensions={1} + ROOT add0 = f32[3,8]{1,0} add(par0,concat0) + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + auto copy = FindInstruction(module.get(), "copy.1"); + auto concat = FindInstruction(module.get(), "concat0"); + EXPECT_EQ(concat->operand(0), copy); + EXPECT_TRUE( + LayoutUtil::Equal(concat->shape().layout(), copy->shape().layout())); +} + +TEST_F(LayoutAssignmentTest, + ConvolutionOperandWithImplicitLayoutChangeNotCopied) { + const char* module_str = R"( + HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied + + ENTRY 
ConvolutionOperandWithImplicitLayoutChangeNotCopied { + par0 = f32[128,3,230,230]{2,3,1,0} parameter(0) + par1 = f32[7,7,3,64]{3,2,0,1} parameter(1) + ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1), + window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01, + feature_group_count=1 + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + auto copy = FindInstruction(module.get(), "copy.1"); + EXPECT_EQ(copy, nullptr); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index cdd3daf73b8ac1a4d1ec3c81224c2c0bfe8e5811..540bbb7c7a74f65ab70f4c6704d6600db2adbb60 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -38,6 +38,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -69,6 +70,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", "@llvm//:support", "@llvm//:target", @@ -88,6 +91,9 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -103,6 +109,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -120,6 +128,7 @@ 
cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:core", ], ) @@ -133,9 +142,7 @@ cc_library( ":llvm_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "@llvm//:core", @@ -159,6 +166,7 @@ cc_library( "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -193,7 +201,10 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", "//tensorflow/compiler/xla/service/gpu:partition_assignment", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@llvm//:core", + "@llvm//:support", ], ) @@ -208,6 +219,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", "@llvm//:core", ], ) @@ -219,7 +231,7 @@ cc_library( deps = [ ":llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -230,6 +242,7 @@ cc_library( hdrs = ["buffer_assignment_util.h"], deps = [ "//tensorflow/compiler/xla/service:buffer_assignment", + "@com_google_absl//absl/strings", ], ) @@ -242,3 +255,12 @@ cc_library( "@llvm//:core", ], ) + +cc_library( + name = "ir_builder_mixin", + srcs = [], + hdrs = ["ir_builder_mixin.h"], + deps = [ + "@llvm//:core", + ], +) diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index fe9eab93aae95557e3ee27a64c09b78f37ac2348..8d9fa99d82b4e49b653d9f05cc9baa5e3fdcefa6 100644 
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ +#include "absl/strings/str_cat.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace llvm_ir { diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index fe5ec1cc66d06e85ce70625ef7cf764a37b29166..b6ae4932f5707f1d15af1e09a735a7de2e48fac5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -61,7 +61,7 @@ ENTRY while3 { ; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]] ; ; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params -; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0 +; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %buffer_table, i64 0 ; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]] ; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float* ; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]] diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc 
b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc index 4eb5d9fb4750927ca189e02f312b2d6be7fdd418..bdce4a171b8a58f617f1d56e6cf6db5354846703 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" +#include "absl/strings/str_cat.h" namespace xla { namespace llvm_ir { @@ -48,7 +49,7 @@ string ConstantBufferAllocationToGlobalName( c = '_'; } } - return tensorflow::strings::StrCat("buffer_for_", instr_name); + return absl::StrCat("buffer_for_", instr_name); } const Literal& LiteralForConstantAllocation( diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 27fbb11e2ede66a1268e7e949634b2c7d29cbc1c..cc2e862f2eb9a49099c5f90efe1b29fb77c8f106 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -40,7 +40,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( const Shape& update_shape, const ElementGenerator& start_indices_generator, bool is_signed, ElementGenerator update_array_generator, const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions, - tensorflow::StringPiece name, llvm::IRBuilder<>* b) { + absl::string_view name, llvm::IRBuilder<>* b) { const Shape& output_shape = output_array.GetShape(); // Read start indices from start_indices_generator. 
@@ -99,10 +99,10 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( return LoopEmitter(loop_body_emitter, update_shape, b).EmitLoop(name); } -Status EmitDynamicUpdateSliceInPlace( - tensorflow::gtl::ArraySlice operand_arrays, - const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* b) { +Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, + const IrArray& output_array, + absl::string_view name, + llvm::IRBuilder<>* b) { VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name; // No need to use operand_arrays[0], the input array of the @@ -130,8 +130,7 @@ Status EmitDynamicUpdateSliceInPlace( // // Emits a sequential loop if launch_dimensions is null. static Status EmitFusedDynamicUpdateSliceInPlaceImpl( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions* launch_dimensions, llvm::IRBuilder<>* b) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); @@ -174,8 +173,7 @@ static Status EmitFusedDynamicUpdateSliceInPlaceImpl( } Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( @@ -184,8 +182,7 @@ Status EmitFusedDynamicUpdateSliceInPlace( } Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b) { return EmitFusedDynamicUpdateSliceInPlaceImpl( diff --git 
a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index 3502577d236a099e0b721b98217b758696966821..fb3e4eb97cae06f2a0c87dd7118b8332048df56e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -63,26 +63,24 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace( // Emits IR for running the given dynamic-update-slice op in-place -- that is, // where the input and output buffers share the same slice, so we can simply // modify the input/output buffer without touching any of the other elements. -Status EmitDynamicUpdateSliceInPlace( - tensorflow::gtl::ArraySlice operand_arrays, - const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* b); +Status EmitDynamicUpdateSliceInPlace(absl::Span operand_arrays, + const IrArray& output_array, + absl::string_view name, + llvm::IRBuilder<>* b); // Given a loop-fusion node whose root is a dynamic-update-slice op whose // array-to-be-updated and output share the same buffer slice, emits // (sequential) code for a fusion node that does the dynamic-update-slice in // place. Status EmitFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, llvm::IRBuilder<>* b); // Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with // the given launch dimensions. 
Status EmitParallelFusedDynamicUpdateSliceInPlace( - HloInstruction* fusion, - tensorflow::gtl::ArraySlice fusion_operand_arrays, + HloInstruction* fusion, absl::Span fusion_operand_arrays, const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, const gpu::LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* b); diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 72ede377e1a505d5e4916915e18827e1a0f3fdf9..b606c993a2d58a6d177af10de7b214de130c2279 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -98,7 +98,7 @@ Status FusedIrEmitter::HandleGetTupleElement( return Unimplemented( "GetTupleElement fusion currently only supports" " parameter operands, but found operand: %s", - operand->name().c_str()); + operand->name()); } // Emit code to lookup tuple element pointer, and store it in 'gte_values_'. llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement( @@ -147,7 +147,7 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { } Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { - tensorflow::gtl::ArraySlice operands(tuple->operands()); + absl::Span operands(tuple->operands()); std::vector operand_elemental_ir_types; for (HloInstruction* operand : operands) { operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType( diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 30471480c4fb3ce3bf3226a28e9d2ffa79ae5f29..44d21fa750a532633f46614002d59c90fc0b5d40 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -54,7 +54,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { public: using Generator = llvm_ir::ElementGenerator; - FusedIrEmitter(tensorflow::gtl::ArraySlice parameter_arrays, + FusedIrEmitter(absl::Span parameter_arrays, ElementalIrEmitter* elemental_emitter) : parameter_arrays_(parameter_arrays), tiled_parameter_info_(nullptr), @@ -94,7 +94,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { private: // Arrays of parameters of fusion instruction - tensorflow::gtl::ArraySlice parameter_arrays_; + absl::Span parameter_arrays_; const llvm_ir::TiledParameterInfo* tiled_parameter_info_; ElementalIrEmitter* elemental_emitter_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 2b6caee6aa72f426cf85c8c56c3ef500ff8c5d3d..67f7423121177e2ca1e3384341dad2644c8f5e34 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -73,7 +73,7 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, Delinearize(&multidim_, linear, shape, b); } -IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, +IrArray::Index::Index(absl::Span multidim, llvm::Value* linear, const Shape& shape) : multidim_(multidim.begin(), multidim.end()), linear_(linear), @@ -92,7 +92,7 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, << " should have a layout."; } -IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, +IrArray::Index::Index(absl::Span multidim, const 
Shape& shape, llvm::IRBuilder<>* b) : multidim_(multidim.begin(), multidim.end()), layout_(shape.layout()), @@ -147,16 +147,15 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( // indices in the same common factor. for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { llvm::Value* logical_linear_index = - Index(tensorflow::gtl::ArraySlice( - multidim_, common_factors[k].second, + Index(absl::Span(multidim_).subspan( + common_factors[k].second, common_factors[k + 1].second - common_factors[k].second), index_type_) - .Linearize( - tensorflow::gtl::ArraySlice( - AsInt64Slice(output_shape.dimensions()), - common_factors[k].second, - common_factors[k + 1].second - common_factors[k].second), - builder); + .Linearize(AsInt64Slice(output_shape.dimensions()) + .subspan(common_factors[k].second, + common_factors[k + 1].second - + common_factors[k].second), + builder); // Delinearizes logical_linear_index for the source array in row-major // collapsed order. The first rank-1 indices are the remainder of the // linear index by each dimension size. 
@@ -185,9 +184,8 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( } IrArray::Index IrArray::Index::SourceIndexOfSlice( - const Shape& shape, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice strides, - llvm::IRBuilder<>* builder) const { + const Shape& shape, absl::Span starts, + absl::Span strides, llvm::IRBuilder<>* builder) const { Index source_index(index_type_, multidim_.size()); for (int i = 0; i < multidim_.size(); ++i) { int64 stride = strides[i]; @@ -208,7 +206,7 @@ IrArray::Index IrArray::Index::SourceIndexOfSlice( IrArray::Index IrArray::Index::SourceIndexOfTranspose( const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, + absl::Span dimension_mapping, llvm::IRBuilder<>* builder) const { std::vector operand_multidim_index = Permute(dimension_mapping, multidim()); @@ -257,7 +255,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( IrArray::Index IrArray::Index::SourceIndexOfBroadcast( const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, + absl::Span dimension_mapping, llvm::IRBuilder<>* builder) const { int64 rank = ShapeUtil::Rank(operand_shape); std::vector source_index(rank); @@ -322,9 +320,8 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( return Index(source_index, linear, operand_shape); } -llvm::Value* IrArray::Index::Linearize( - tensorflow::gtl::ArraySlice dimensions, - llvm::IRBuilder<>* builder) const { +llvm::Value* IrArray::Index::Linearize(absl::Span dimensions, + llvm::IRBuilder<>* builder) const { // Each dimension is multiplied by the product of the sizes of all // earlier dimensions and added to the accumulator logical_linear_index. 
CHECK_EQ(size(), dimensions.size()); @@ -342,9 +339,9 @@ llvm::Value* IrArray::Index::Linearize( return logical_linear_index; } -llvm::Value* IrArray::EmitArrayElementAddress( - const IrArray::Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name) const { +llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, + llvm::IRBuilder<>* b, + absl::string_view name) const { if (ShapeUtil::IsScalar(*shape_)) { // Special handling of scalars: a scalar pretends to have the same value for // every index, thus effectively implementing broadcasting of its value @@ -402,7 +399,7 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name) const { + absl::string_view name) const { llvm::Value* element_address = EmitArrayElementAddress(index, b, name); llvm::LoadInst* load = b->CreateLoad(element_address); AnnotateLoadStoreInstructionWithMetadata(load); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 28ca793e3eeaed86664bfa6aa859a38f2c4dc6f3..f4b05f29c38529b3cce81b4c8ee6fae5c00cafcc 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -19,13 +19,14 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -69,7 +70,7 @@ class IrArray { // Constructs an index from multi-dimensional index "multidim". 
The linear // index is set to nullptr. - explicit Index(tensorflow::gtl::ArraySlice multidim, + explicit Index(absl::Span multidim, llvm::Type* index_ty = nullptr) : multidim_(multidim.begin(), multidim.end()) { if (size() == 0) { @@ -81,7 +82,7 @@ class IrArray { } } CHECK_NE(index_type_, nullptr); - CHECK(c_all_of(multidim, [&](llvm::Value* v) { + CHECK(absl::c_all_of(multidim, [&](llvm::Value* v) { return index_type_ == v->getType(); })); } @@ -98,14 +99,14 @@ class IrArray { // that it indexes into. // // Precondition: "shape" has a layout. - Index(tensorflow::gtl::ArraySlice multidim, - const Shape& shape, llvm::IRBuilder<>* b); + Index(absl::Span multidim, const Shape& shape, + llvm::IRBuilder<>* b); // Constructs an index from both a multi-dimensional index and a linear // index. "shape" has the same meaning as that in the constructor that takes // only a linear index. - Index(tensorflow::gtl::ArraySlice multidim, - llvm::Value* linear, const Shape& shape); + Index(absl::Span multidim, llvm::Value* linear, + const Shape& shape); const std::vector& multidim() const { return multidim_; } llvm::Value* linear() const { return linear_; } @@ -144,17 +145,15 @@ class IrArray { // by starting indices `starts` and stride values `strides`. // // Precondition: "this" is an index into a slice whose shape is `shape`. - Index SourceIndexOfSlice(const Shape& shape, - tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice strides, + Index SourceIndexOfSlice(const Shape& shape, absl::Span starts, + absl::Span strides, llvm::IRBuilder<>* builder) const; // Given that "this" is the target index of a transpose from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. 
- Index SourceIndexOfTranspose( - const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, - llvm::IRBuilder<>* builder) const; + Index SourceIndexOfTranspose(const Shape& shape, const Shape& operand_shape, + absl::Span dimension_mapping, + llvm::IRBuilder<>* builder) const; // Given that "this" is the target index of a bitcast from `operand_shape` // to `shape`, returns the source index. @@ -163,14 +162,13 @@ class IrArray { // Given that "this" is the target index of a broadcast from `operand_shape` // to `shape` with the given dimension mapping, returns the source index. - Index SourceIndexOfBroadcast( - const Shape& shape, const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimension_mapping, - llvm::IRBuilder<>* builder) const; + Index SourceIndexOfBroadcast(const Shape& shape, const Shape& operand_shape, + absl::Span dimension_mapping, + llvm::IRBuilder<>* builder) const; // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and // returns the index into the sole dimension 0 of the new shape. - llvm::Value* Linearize(tensorflow::gtl::ArraySlice dimensions, + llvm::Value* Linearize(absl::Span dimensions, llvm::IRBuilder<>* builder) const; llvm::Type* GetType() const { return index_type_; } @@ -240,7 +238,7 @@ class IrArray { // The optional name is useful for debugging when looking at // the emitted LLVM IR. llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name = "") const; + absl::string_view name = "") const; // Attach metadata this IrArray instance knows about to "instruction". void AnnotateLoadStoreInstructionWithMetadata( @@ -254,7 +252,7 @@ class IrArray { // The optional name is useful for debugging when looking at // the emitted LLVM IR. 
llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name = "") const; + absl::string_view name = "") const; // Emit IR to write the given value to the array element at the given index. void EmitWriteArrayElement(const Index& index, llvm::Value* value, diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h new file mode 100644 index 0000000000000000000000000000000000000000..abc06fb7b4245294df2dc20d25a22ac4fdaeb4cf --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -0,0 +1,400 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ + +#include "llvm/IR/IRBuilder.h" + +namespace xla { + +// Mixin class that injects more ergonomic versions of llvm::IRBuilder methods +// into a class. Intended to be used as a CRTP base class, like: +// +// class MyIrEmitter : public IrBuilderMixin { +// llvm::IRBuilder<>* builder() { return builder_; } +// +// void EmitFoo(HloInstruction* foo) { +// Add(Mul(...), FPToUI(...)); +// } +// }; + +template +class IrBuilderMixin { + protected: + template + llvm::Value* Add(Args&&... 
args) { + return mixin_builder()->CreateAdd(std::forward(args)...); + } + + template + llvm::LoadInst* AlignedLoad(Args&&... args) { + return mixin_builder()->CreateAlignedLoad(std::forward(args)...); + } + + template + llvm::StoreInst* AlignedStore(Args&&... args) { + return mixin_builder()->CreateAlignedStore(std::forward(args)...); + } + + template + llvm::AllocaInst* Alloca(Args&&... args) { + return mixin_builder()->CreateAlloca(std::forward(args)...); + } + + template + llvm::Value* And(Args&&... args) { + return mixin_builder()->CreateAnd(std::forward(args)...); + } + + template + llvm::Value* AtomicCmpXchg(Args&&... args) { + return mixin_builder()->CreateAtomicCmpXchg(std::forward(args)...); + } + + template + llvm::Value* AtomicRMW(Args&&... args) { + return mixin_builder()->CreateAtomicRMW(std::forward(args)...); + } + + template + llvm::Value* BitCast(Args&&... args) { + return mixin_builder()->CreateBitCast(std::forward(args)...); + } + + template + llvm::Value* Br(Args&&... args) { + return mixin_builder()->CreateBr(std::forward(args)...); + } + + llvm::CallInst* Call(llvm::Value* callee, + llvm::ArrayRef args = llvm::None, + const llvm::Twine& name = "", + llvm::MDNode* fp_math_tag = nullptr) { + return mixin_builder()->CreateCall(callee, args, name, fp_math_tag); + } + + template + llvm::BranchInst* CondBr(Args&&... args) { + return mixin_builder()->CreateCondBr(std::forward(args)...); + } + + template + llvm::Value* ConstInBoundsGEP1_32(Args&&... args) { + return mixin_builder()->CreateConstInBoundsGEP1_32( + std::forward(args)...); + } + + template + llvm::Value* FAdd(Args&&... args) { + return mixin_builder()->CreateFAdd(std::forward(args)...); + } + + template + llvm::Value* FMul(Args&&... 
args) { + return mixin_builder()->CreateFMul(std::forward(args)...); + } + + llvm::Value* GEP(llvm::Value* ptr, llvm::ArrayRef idx_list, + const llvm::Twine& name = "") { + return mixin_builder()->CreateGEP(ptr, idx_list, name); + } + + template + llvm::Value* ICmpEQ(Args&&... args) { + return mixin_builder()->CreateICmpEQ(std::forward(args)...); + } + + template + llvm::Value* ICmpNE(Args&&... args) { + return mixin_builder()->CreateICmpNE(std::forward(args)...); + } + + template + llvm::Value* ICmpULE(Args&&... args) { + return mixin_builder()->CreateICmpULE(std::forward(args)...); + } + + template + llvm::Value* ICmpULT(Args&&... args) { + return mixin_builder()->CreateICmpULT(std::forward(args)...); + } + + llvm::Value* InBoundsGEP(llvm::Value* ptr, + llvm::ArrayRef idx_list, + const llvm::Twine& name = "") { + return mixin_builder()->CreateInBoundsGEP(ptr, idx_list, name); + } + + llvm::Value* ExtractValue(llvm::Value* agg, llvm::ArrayRef idxs, + const llvm::Twine& name = "") { + return mixin_builder()->CreateExtractValue(agg, idxs, name); + } + + llvm::Value* InsertValue(llvm::Value* agg, llvm::Value* val, + llvm::ArrayRef idxs, + const llvm::Twine& name = "") { + return mixin_builder()->CreateInsertValue(agg, val, idxs, name); + } + + template + llvm::Value* IntToPtr(Args&&... args) { + return mixin_builder()->CreateIntToPtr(std::forward(args)...); + } + + template + llvm::LoadInst* Load(Args&&... args) { + return mixin_builder()->CreateLoad(std::forward(args)...); + } + + template + llvm::CallInst* MemCpy(Args&&... args) { + return mixin_builder()->CreateMemCpy(std::forward(args)...); + } + + template + llvm::Value* Mul(Args&&... args) { + return mixin_builder()->CreateMul(std::forward(args)...); + } + + template + llvm::Value* NSWAdd(Args&&... args) { + return mixin_builder()->CreateNSWAdd(std::forward(args)...); + } + + template + llvm::Value* NSWMul(Args&&... 
args) { + return mixin_builder()->CreateNSWMul(std::forward(args)...); + } + + template + llvm::Value* NSWSub(Args&&... args) { + return mixin_builder()->CreateNSWSub(std::forward(args)...); + } + + template + llvm::Value* Or(Args&&... args) { + return mixin_builder()->CreateOr(std::forward(args)...); + } + + template + llvm::Value* PointerCast(Args&&... args) { + return mixin_builder()->CreatePointerCast(std::forward(args)...); + } + + template + llvm::Value* PtrToInt(Args&&... args) { + return mixin_builder()->CreatePtrToInt(std::forward(args)...); + } + + template + llvm::Value* SDiv(Args&&... args) { + return mixin_builder()->CreateSDiv(std::forward(args)...); + } + + template + llvm::Value* Select(Args&&... args) { + return mixin_builder()->CreateSelect(std::forward(args)...); + } + + template + llvm::Value* SRem(Args&&... args) { + return mixin_builder()->CreateSRem(std::forward(args)...); + } + + template + llvm::StoreInst* Store(Args&&... args) { + return mixin_builder()->CreateStore(std::forward(args)...); + } + + template + llvm::Value* UDiv(Args&&... args) { + return mixin_builder()->CreateUDiv(std::forward(args)...); + } + + template + llvm::Value* URem(Args&&... args) { + return mixin_builder()->CreateURem(std::forward(args)...); + } + + template + llvm::Value* VectorSplat(Args&&... args) { + return mixin_builder()->CreateVectorSplat(std::forward(args)...); + } + + template + llvm::Value* ZExtOrTrunc(Args&&... args) { + return mixin_builder()->CreateZExtOrTrunc(std::forward(args)...); + } + + template + llvm::Value* AShr(Args&&... args) { + return mixin_builder()->CreateAShr(std::forward(args)...); + } + + template + llvm::Value* FCmpOEQ(Args&&... args) { + return mixin_builder()->CreateFCmpOEQ(std::forward(args)...); + } + + template + llvm::Value* FCmpOLT(Args&&... args) { + return mixin_builder()->CreateFCmpOLT(std::forward(args)...); + } + + template + llvm::Value* FCmpONE(Args&&... 
args) { + return mixin_builder()->CreateFCmpONE(std::forward(args)...); + } + + template + llvm::Value* FCmpUNE(Args&&... args) { + return mixin_builder()->CreateFCmpUNE(std::forward(args)...); + } + + template + llvm::Value* FDiv(Args&&... args) { + return mixin_builder()->CreateFDiv(std::forward(args)...); + } + + template + llvm::Value* FNeg(Args&&... args) { + return mixin_builder()->CreateFNeg(std::forward(args)...); + } + + template + llvm::Value* FPCast(Args&&... args) { + return mixin_builder()->CreateFPCast(std::forward(args)...); + } + + template + llvm::Value* FPToSI(Args&&... args) { + return mixin_builder()->CreateFPToSI(std::forward(args)...); + } + + template + llvm::Value* FPToUI(Args&&... args) { + return mixin_builder()->CreateFPToUI(std::forward(args)...); + } + + template + llvm::Value* FPTrunc(Args&&... args) { + return mixin_builder()->CreateFPTrunc(std::forward(args)...); + } + + template + llvm::Value* FRem(Args&&... args) { + return mixin_builder()->CreateFRem(std::forward(args)...); + } + + template + llvm::Value* FSub(Args&&... args) { + return mixin_builder()->CreateFSub(std::forward(args)...); + } + + template + llvm::Value* ICmpSGE(Args&&... args) { + return mixin_builder()->CreateICmpSGE(std::forward(args)...); + } + + template + llvm::Value* ICmpSLT(Args&&... args) { + return mixin_builder()->CreateICmpSLT(std::forward(args)...); + } + + template + llvm::Value* IntCast(Args&&... args) { + return mixin_builder()->CreateIntCast(std::forward(args)...); + } + + template + llvm::Value* LShr(Args&&... args) { + return mixin_builder()->CreateLShr(std::forward(args)...); + } + + template + llvm::Value* MemSet(Args&&... args) { + return mixin_builder()->CreateMemSet(std::forward(args)...); + } + + template + llvm::Value* Neg(Args&&... args) { + return mixin_builder()->CreateNeg(std::forward(args)...); + } + + template + llvm::Value* Not(Args&&... 
args) { + return mixin_builder()->CreateNot(std::forward(args)...); + } + + template + llvm::PHINode* PHI(Args&&... args) { + return mixin_builder()->CreatePHI(std::forward(args)...); + } + + template + llvm::Value* RetVoid(Args&&... args) { + return mixin_builder()->CreateRetVoid(std::forward(args)...); + } + + template + llvm::Value* SExtOrTrunc(Args&&... args) { + return mixin_builder()->CreateSExtOrTrunc(std::forward(args)...); + } + + template + llvm::Value* Shl(Args&&... args) { + return mixin_builder()->CreateShl(std::forward(args)...); + } + + template + llvm::Value* SIToFP(Args&&... args) { + return mixin_builder()->CreateSIToFP(std::forward(args)...); + } + + template + llvm::Value* Sub(Args&&... args) { + return mixin_builder()->CreateSub(std::forward(args)...); + } + + template + llvm::Value* Trunc(Args&&... args) { + return mixin_builder()->CreateTrunc(std::forward(args)...); + } + + template + llvm::Value* UIToFP(Args&&... args) { + return mixin_builder()->CreateUIToFP(std::forward(args)...); + } + + template + llvm::Value* Unreachable(Args&&... args) { + return mixin_builder()->CreateUnreachable(std::forward(args)...); + } + + template + llvm::Value* Xor(Args&&... args) { + return mixin_builder()->CreateXor(std::forward(args)...); + } + + private: + llvm::IRBuilder<>* mixin_builder() { + return static_cast(this)->builder(); + } +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index b79567369aa532c4963e3941f6cb9844cd1476dd..bd0139f85b6a5c5dc23dad962263038451921e65 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -19,7 +19,7 @@ limitations under the License. 
namespace xla { Status KernelSupportLibrary::For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { return If(b_->CreateICmpSLT(start, end), [&]() -> Status { @@ -30,7 +30,7 @@ Status KernelSupportLibrary::For( } Status KernelSupportLibrary::For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& for_body_generator) { @@ -56,7 +56,7 @@ Status KernelSupportLibrary::For( } Status KernelSupportLibrary::If( - tensorflow::StringPiece name, llvm::Value* condition, + absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_); @@ -70,7 +70,7 @@ Status KernelSupportLibrary::If( void KernelSupportLibrary::EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, + absl::string_view kernel_name, KernelSupportLibrary::ArgumentVector arguments, const std::function& kernel_body_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index b00f903d56a83c5b76188007702470c44c55c213..43fec311f150d6054f6ad24f99db332f90ff94a3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ #include +#include "absl/strings/string_view.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { // A thin wrapper around llvm_loop.h to make code generating structured control @@ -49,13 +49,13 @@ class KernelSupportLibrary { // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator); void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { @@ -67,7 +67,7 @@ class KernelSupportLibrary { })); } - Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step, + Status For(absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { @@ -77,7 +77,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { ForReturnVoid(name, /*start=*/b_->getInt64(start), @@ -99,13 +99,13 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // 
`for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& for_body_generator); - void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + void ForReturnVoid(absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& @@ -129,7 +129,7 @@ class KernelSupportLibrary { peel_first_iteration, for_body_generator); } - void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + void ForReturnVoid(absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, bool peel_first_iteration, const std::function& @@ -140,7 +140,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { return For(name, start, end, step, @@ -151,7 +151,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { ForReturnVoid(name, start, end, step, @@ -162,8 +162,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), /*peel_first_iteration=*/false, @@ -173,8 +172,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, + 
absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { ForReturnVoid(name, start, end, llvm::ConstantInt::get(start->getType(), step), @@ -182,7 +180,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { return For(name, /*start=*/b_->getInt64(start), /*end=*/b_->getInt64(end), @@ -190,7 +188,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { ForReturnVoid(name, /*start=*/b_->getInt64(start), /*end=*/b_->getInt64(end), @@ -203,7 +201,7 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - Status If(tensorflow::StringPiece name, llvm::Value* condition, + Status If(absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator = []() -> Status { return Status::OK(); }); @@ -222,7 +220,7 @@ class KernelSupportLibrary { IfReturnVoid("", condition, true_block_generator, false_block_generator); } - void IfReturnVoid(tensorflow::StringPiece name, llvm::Value* condition, + void IfReturnVoid(absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator = []() { }) { @@ -237,7 +235,7 @@ class KernelSupportLibrary { })); } - using ArgumentVector = tensorflow::gtl::ArraySlice; + using ArgumentVector = absl::Span; // Generates the following control flow structure: // @@ -259,13 +257,13 @@ class KernelSupportLibrary { // Currently we only support at most one nullptr value in `arguments`. 
static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, ArgumentVector arguments, + absl::string_view kernel_name, ArgumentVector arguments, const std::function& kernel_body_generator); // Thin wrappers around the more general EmitAndCallOutlinedKernel above. static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, const std::function& kernel_body_generator) { @@ -278,7 +276,7 @@ class KernelSupportLibrary { static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, llvm::Value* arg3, const std::function& kernel_body_generator) { @@ -296,4 +294,4 @@ class KernelSupportLibrary { }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index 35b394127288d816952b48c84b193257bab0bcda..e5fbdbd51b8a9aa14decadedd1eeb3bdbf831738 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -28,7 +28,7 @@ namespace { // Returns the indices of the first elements of all consecutive subarrays of the // given array. 
For example: // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} -std::vector ConsecutiveSegments(tensorflow::gtl::ArraySlice xs) { +std::vector ConsecutiveSegments(absl::Span xs) { std::vector is = {0}; for (size_t i = 1; i < xs.size(); ++i) { if (1 != xs[i] - xs[i - 1]) { @@ -40,8 +40,7 @@ std::vector ConsecutiveSegments(tensorflow::gtl::ArraySlice xs) { // Merges the sequences of dimensions of the given shape which start at the // given indices `segs`. -Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, - const Shape& shape) { +Shape MergeDimensions(absl::Span segs, const Shape& shape) { std::vector dimensions; for (size_t i = 1; i <= segs.size(); ++i) { dimensions.push_back(std::accumulate( @@ -55,10 +54,10 @@ Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, } } // namespace -tensorflow::gtl::optional > FindTranspose021( - const Shape& a, const Shape& b) { +absl::optional > FindTranspose021(const Shape& a, + const Shape& b) { if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } std::vector perm(a.dimensions().size()); @@ -88,7 +87,7 @@ tensorflow::gtl::optional > FindTranspose021( return dims_021; } - return tensorflow::gtl::nullopt; + return absl::nullopt; } IrArray::Index GetUnreducedOutputIndex( diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index ccb9b8ba3e6b0079664f2da92ce67224e176fa1d..5ea05b3188a1c0881e4c0c41625d530aff1b1205 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -36,8 +36,8 @@ namespace llvm_ir { // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the // reduced shape of `b` or the 0-2-1 shape. 
-tensorflow::gtl::optional > FindTranspose021(const Shape& a, - const Shape& b); +absl::optional > FindTranspose021(const Shape& a, + const Shape& b); // Return the unreduced output index corresponding to the given reduced output // index. @@ -50,7 +50,7 @@ IrArray::Index GetUnreducedOutputIndex( // for 021 transpose. class TiledParameterInfo { public: - TiledParameterInfo(tensorflow::gtl::ArraySlice param_buffers, + TiledParameterInfo(absl::Span param_buffers, llvm::Value* y, llvm::Value* x) : param_buffers_(param_buffers), y_(y), x_(x) {} @@ -67,7 +67,7 @@ class TiledParameterInfo { private: // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr // if the parameter is not tiled. - tensorflow::gtl::ArraySlice param_buffers_; + absl::Span param_buffers_; // The y coordinate within a tile. llvm::Value* y_; // The x coordinate within a tile. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index ba7f94834c7fd04d97cec012537244323308b8ce..219a9f221fbd116cdfbaf17985e21d82aefd079d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -25,19 +26,17 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace llvm_ir { -ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, +ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization) - : prefix_(std::string(prefix)), - suffix_(std::string(suffix)), + : prefix_(prefix), + suffix_(suffix), start_index_(start_index), end_index_(end_index), step_(step), @@ -46,9 +45,9 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr ForLoop::EmitForLoop( - tensorflow::StringPiece prefix, llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, - UnrollMode unroll_mode, bool prevent_vectorization) { + absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index, + llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode, + bool prevent_vectorization) { std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, end_index, step, unroll_mode, prevent_vectorization)); @@ -168,16 +167,16 @@ std::vector ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) { return result; } -string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { +string ForLoop::GetQualifiedName(absl::string_view name) { return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_)); } -llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, +llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b) { return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b); } -std::unique_ptr 
ForLoopNest::AddLoop(tensorflow::StringPiece suffix, +std::unique_ptr ForLoopNest::AddLoop(absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode, @@ -186,12 +185,9 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, unroll_mode, prevent_vectorization); } -std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, - llvm::Value* stride, - UnrollMode unroll_mode, - bool prevent_vectorization) { +std::unique_ptr ForLoopNest::AddLoop( + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, + llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); @@ -216,7 +212,7 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); @@ -227,7 +223,7 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); @@ -238,22 +234,22 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, } IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, - tensorflow::StringPiece suffix) { + absl::string_view suffix) { std::vector dimensions(ShapeUtil::Rank(shape)); std::iota(dimensions.begin(), dimensions.end(), 0); return AddLoopsForShapeOnDimensions(shape, dimensions, suffix); } IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( - const Shape& shape, tensorflow::gtl::ArraySlice dimensions, - 
tensorflow::StringPiece suffix) { + const Shape& shape, absl::Span dimensions, + absl::string_view suffix) { llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); for (int64 dimension : dimensions) { std::unique_ptr loop = AddLoop( /*start_index=*/0, /*end_index=*/shape.dimensions(dimension), /*suffix=*/ - llvm_ir::IrName(suffix, tensorflow::strings::StrCat(dimension))); + llvm_ir::IrName(suffix, absl::StrCat(dimension))); index[dimension] = loop->GetIndVarValue(); } return index; @@ -261,7 +257,7 @@ IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( IrArray::Index ForLoopNest::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, - tensorflow::StringPiece name_suffix) { + absl::string_view name_suffix) { // Prepares the dimension list we will use to emit the loop nest. Outermost // loops are added first. Add loops in major-to-minor order, and skip the // 'dimension_to_skip' dimension. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index a4fed5c8dc55d38d25031252e3960404a5bf84e6..ac3bba3c9fd6a9eb4e7822474963fcc5a394baf7 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -19,15 +19,15 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -78,7 +78,7 @@ class ForLoop { // `unroll_mode` specifies the desired LLVM unrolling behavior for generated // loop. static std::unique_ptr EmitForLoop( - tensorflow::StringPiece prefix, llvm::Value* start_index, + absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -133,19 +133,18 @@ class ForLoop { // Allow ForLoopNest to call this private constructor. friend class ForLoopNest; - ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, + ForLoop(absl::string_view prefix, absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. void Emit(llvm::IRBuilder<>* b); - llvm::BasicBlock* CreateLoopBB(tensorflow::StringPiece name, - llvm::IRBuilder<>* b); + llvm::BasicBlock* CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b); // Creates a name for an LLVM construct, appending prefix_ and suffix_, if // they are set. - string GetQualifiedName(tensorflow::StringPiece name); + string GetQualifiedName(absl::string_view name); // Return a list of metadata nodes that should be associated with the // llvm::Loop for this `ForLoop`. 
@@ -182,9 +181,9 @@ class ForLoopNest { SetIndexType(index_ty); } - ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* b, + ForLoopNest(absl::string_view name, llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) - : name_(std::string(name)), + : name_(name), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), @@ -197,14 +196,14 @@ class ForLoopNest { // been added then emit loop inside the body of the last added loop. // unroll_mode is used to emit metadata that controls LLVM unrolling. std::unique_ptr AddLoop( - tensorflow::StringPiece suffix, llvm::Value* start_index, + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( - tensorflow::StringPiece suffix, llvm::Value* start_index, + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -213,13 +212,13 @@ class ForLoopNest { // end index are constant. std::unique_ptr AddLoop( int64 start_index, int64 end_index, int64 stride, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( - int64 start_index, int64 end_index, tensorflow::StringPiece suffix, + int64 start_index, int64 end_index, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -234,8 +233,7 @@ class ForLoopNest { // within the shape. 
One possible order for that sequence would be: // // (0,0), (0,1), (0,2), (1,0), (1,1), (1,2) - IrArray::Index AddLoopsForShape(const Shape& shape, - tensorflow::StringPiece suffix); + IrArray::Index AddLoopsForShape(const Shape& shape, absl::string_view suffix); // Add a loop for each dimension in "dimensions". "suffix" is the // name suffix of the indvar and basic blocks in this new loop nest. @@ -244,8 +242,8 @@ class ForLoopNest { // size equals the rank of shape and there is a null for each // dimension that is not in "dimensions". IrArray::Index AddLoopsForShapeOnDimensions( - const Shape& shape, tensorflow::gtl::ArraySlice dimensions, - tensorflow::StringPiece suffix); + const Shape& shape, absl::Span dimensions, + absl::string_view suffix); // Emits a series of nested loops for iterating over an operand array. Loops // are constructed in major to minor dimension layout order. No loop is @@ -256,7 +254,7 @@ class ForLoopNest { // basic blocks) constructed by this method. IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, - tensorflow::StringPiece name_suffix); + absl::string_view name_suffix); // Convenience methods which return particular basic blocks of the outermost // or innermost loops. These methods return nullptr if no loops have been diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index e6126881af8b8123e08a4eaa934b52a7fd378ce6..1a53c026be340ca3bec3a49b11666d6124728130 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/MDBuilder.h" @@ -34,8 +36,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -61,7 +61,7 @@ string AsString(const std::string& str) { return string(str.data(), str.length()); } -llvm::StringRef AsStringRef(tensorflow::StringPiece str) { +llvm::StringRef AsStringRef(absl::string_view str) { return llvm::StringRef(str.data(), str.size()); } @@ -83,11 +83,10 @@ string DumpModuleToString(const llvm::Module& module) { return AsString(buffer_string); } -llvm::Value* EmitCallToIntrinsic( - llvm::Intrinsic::ID intrinsic_id, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice overloaded_types, - llvm::IRBuilder<>* b) { +llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, + absl::Span operands, + absl::Span overloaded_types, + llvm::IRBuilder<>* b) { llvm::Module* module = ModuleFromIRBuilder(b); llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( module, intrinsic_id, AsArrayRef(overloaded_types)); @@ -262,15 +261,17 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b, int alignment) { return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment); } -llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( - llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* b, int alignment) { +llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, + llvm::Value* element_count, + absl::string_view name, + llvm::IRBuilder<>* b, + int alignment) { llvm::IRBuilder<>::InsertPoint insert_point = b->saveIP(); llvm::Function* function = 
b->GetInsertBlock()->getParent(); b->SetInsertPoint(&function->getEntryBlock(), @@ -285,7 +286,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( } llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b) { return llvm::BasicBlock::Create( /*Context=*/b->getContext(), @@ -294,27 +295,25 @@ llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, /*InsertBefore*/ insert_before); } -LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, +LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name, llvm::IRBuilder<>* b, bool emit_else) { llvm_ir::LlvmIfData if_data; if_data.if_block = b->GetInsertBlock(); if_data.true_block = - CreateBasicBlock(nullptr, tensorflow::strings::StrCat(name, "-true"), b); + CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b); if_data.false_block = - emit_else ? CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-false"), b) + emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b) : nullptr; // Add a terminator to the if block, if necessary. if (if_data.if_block->getTerminator() == nullptr) { b->SetInsertPoint(if_data.if_block); - if_data.after_block = CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-after"), b); + if_data.after_block = + CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b); b->CreateBr(if_data.after_block); } else { if_data.after_block = if_data.if_block->splitBasicBlock( - b->GetInsertPoint(), - AsStringRef(tensorflow::strings::StrCat(name, "-after"))); + b->GetInsertPoint(), AsStringRef(absl::StrCat(name, "-after"))); } // Our basic block should now end with an unconditional branch. 
Remove it; @@ -413,14 +412,14 @@ string IrName(string a) { return a; } -string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b) { +string IrName(absl::string_view a, absl::string_view b) { if (!a.empty() && !b.empty()) { - return IrName(tensorflow::strings::StrCat(a, ".", b)); + return IrName(absl::StrCat(a, ".", b)); } - return IrName(tensorflow::strings::StrCat(a, b)); + return IrName(absl::StrCat(a, b)); } -string IrName(const HloInstruction* a, tensorflow::StringPiece b) { +string IrName(const HloInstruction* a, absl::string_view b) { return IrName(a->name(), b); } @@ -556,7 +555,7 @@ std::map MergeMetadata( return result; } -static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) { +static string GetProcessUniqueIrFileName(absl::string_view prefix) { static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-"); @@ -584,18 +583,16 @@ Status DumpIRToDirectory(const string& directory_name, // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously // dumped from the same process in such cases. string unique_and_safe_file_name = GetProcessUniqueIrFileName( - tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", - optimized ? "with" : "no", "-opt")); + absl::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", + optimized ? "with" : "no", "-opt")); string ir_file_name = tensorflow::io::JoinPath( - directory_name, - tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll")); + directory_name, absl::StrCat(unique_and_safe_file_name, ".ll")); // For some models the embedded constants can be huge, so also dump the module // with the constants stripped to get IR that is easier to manipulate. 
string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath( - directory_name, - tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll")); + directory_name, absl::StrCat(unique_and_safe_file_name, "-noconst.ll")); TF_RETURN_IF_ERROR(CreateAndWriteStringToFile( directory_name, ir_file_name, DumpModuleToString(llvm_module))); @@ -607,8 +604,7 @@ Status DumpIRToDirectory(const string& directory_name, llvm::Function* CreateFunction(llvm::FunctionType* function_type, llvm::GlobalValue::LinkageTypes linkage, bool enable_fast_math, bool optimize_for_size, - tensorflow::StringPiece name, - llvm::Module* module) { + absl::string_view name, llvm::Module* module) { llvm::Function* function = llvm::Function::Create(function_type, linkage, AsStringRef(name), module); function->setCallingConv(llvm::CallingConv::C); @@ -638,7 +634,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { fake_argv_storage.push_back(""); for (const auto& it : options) { // Skip options the XLA backend itself consumes. - if (!tensorflow::str_util::StartsWith(it.first, "xla_")) { + if (!absl::StartsWith(it.first, "xla_")) { if (it.second.empty()) { fake_argv_storage.push_back(it.first); } else { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 09583985342033d486d50910b6f5ca732a9a3756..f59baff263fe7184c6b0821c9dbd9eee205586a6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" @@ -32,8 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" namespace llvm { @@ -47,11 +47,11 @@ namespace llvm_ir { // Convert a std::string (used by LLVM's interfaces) to string. string AsString(const std::string& str); -// Convert a tensorflow::StringPiece to a llvm::StringRef. Note: both -// tensorflow::StringPiece and llvm::StringRef are non-owning pointers into a +// Convert a absl::string_view to a llvm::StringRef. Note: both +// absl::string_view and llvm::StringRef are non-owning pointers into a // string in memory. This method is used to feed strings to LLVM // & Clang APIs that expect llvm::StringRef. -llvm::StringRef AsStringRef(tensorflow::StringPiece str); +llvm::StringRef AsStringRef(absl::string_view str); template llvm::ArrayRef AsArrayRef(const std::vector& vec) { @@ -59,7 +59,7 @@ llvm::ArrayRef AsArrayRef(const std::vector& vec) { } template -llvm::ArrayRef AsArrayRef(const tensorflow::gtl::ArraySlice& slice) { +llvm::ArrayRef AsArrayRef(const absl::Span& slice) { return llvm::ArrayRef(slice.data(), slice.size()); } @@ -88,8 +88,8 @@ string DumpModuleToString(const llvm::Module& module); // - removing all '%'s. // string IrName(string a); -string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b); -string IrName(const HloInstruction* a, tensorflow::StringPiece b = ""); +string IrName(absl::string_view a, absl::string_view b); +string IrName(const HloInstruction* a, absl::string_view b = ""); // Removes special characters from a function name. // @@ -101,11 +101,10 @@ string SanitizeFunctionName(string function_name); // intrinsics (for example, "minnum") must include a type in overloaded_types // for each overloaded type. Typically, overloaded intrinsics have only a single // overloaded type. 
-llvm::Value* EmitCallToIntrinsic( - llvm::Intrinsic::ID intrinsic_id, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice overloaded_types, - llvm::IRBuilder<>* b); +llvm::Value* EmitCallToIntrinsic(llvm::Intrinsic::ID intrinsic_id, + absl::Span operands, + absl::Span overloaded_types, + llvm::IRBuilder<>* b); // Emit float max. Emit maxnum intrinsic is fast math is disabled, or // fcmp+select otherwise @@ -164,21 +163,23 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, // This can be useful to avoid e.g. executing an alloca every time // through a loop. llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b, int alignment = 0); // As EmitAllocaAtFunctionEntry, but allocates element_count entries // instead of a single element. -llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( - llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* b, int alignment = 0); +llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, + llvm::Value* element_count, + absl::string_view name, + llvm::IRBuilder<>* b, + int alignment = 0); // Creates a basic block with the same context and function as for the // builder. Inserts at the end of the function if insert_before is // null. llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b); // Struct with data on a conditional branch in a diamond shape created @@ -210,7 +211,7 @@ struct LlvmIfData { // Currently the insertion point of the builder must be a well-formed // block with a terminator. If you need to use this for a // non-terminated block, just make the function able to do that too. 
-LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, +LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name, llvm::IRBuilder<>* b, bool emit_else = true); // Emits a compare operation between "lhs" and "rhs" with the given predicate, @@ -285,8 +286,7 @@ Status DumpIRToDirectory(const string& directory_name, llvm::Function* CreateFunction(llvm::FunctionType* function_type, llvm::GlobalValue::LinkageTypes linkage, bool enable_fast_math, bool optimize_for_size, - tensorflow::StringPiece name, - llvm::Module* module); + absl::string_view name, llvm::Module* module); // Extracts the xla_backend_extra_options from `config` and passes those that // don't start with xla_ to LLVM. diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 36f5fa195224c20e30a14f72b32eb42a681bb5e9..0dc120e0b0df47f261435f490a8459b49d989b53 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -18,13 +18,13 @@ limitations under the License. 
#include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -69,7 +69,7 @@ static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( } LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, + absl::Span target_arrays, llvm::IRBuilder<>* b) : body_emitter_(MakeBodyEmitterForMultiOutputFusion( target_element_generator, @@ -86,7 +86,7 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, } std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { + absl::string_view loop_name, llvm::Type* index_type) { CHECK_NE(index_type, nullptr); if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. 
@@ -105,7 +105,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + /*suffix=*/absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } @@ -122,7 +122,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name, +Status LoopEmitter::EmitLoop(absl::string_view loop_name, llvm::Type* index_type) { if (index_type == nullptr) { index_type = b_->getInt64Ty(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index c4f5c82086ccfa233e0be118b1de10cce55a51b1..a537c00066b0a68404b142e91283510092b46e2d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -53,8 +53,7 @@ class LoopEmitter { // This is used for multi-output fusion. target_element_generator must // produce an LLVM struct with N elements. LoopEmitter(const ElementGenerator& target_element_generator, - tensorflow::gtl::ArraySlice target_arrays, - llvm::IRBuilder<>* b); + absl::Span target_arrays, llvm::IRBuilder<>* b); LoopEmitter(const LoopEmitter&) = delete; LoopEmitter& operator=(const LoopEmitter&) = delete; @@ -69,10 +68,10 @@ class LoopEmitter { } virtual std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type); + absl::string_view loop_name, llvm::Type* index_type); // Emits a complete loop nest for every element in the given shape. 
- Status EmitLoop(tensorflow::StringPiece loop_name = "", + Status EmitLoop(absl::string_view loop_name = "", llvm::Type* index_type = nullptr); protected: diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index e546f5cc4ae305b40c1bdbcae090daadee11241b..944c79580c133906cd431722fd6b29e6aee5f918 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -16,6 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" @@ -29,8 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -42,7 +43,7 @@ namespace { void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, const IrArray::Index& compare_keys_index, const IrArray& keys_array, - const tensorflow::gtl::optional& values_array, + const absl::optional& values_array, llvm::IRBuilder<>* b) { // if (is_smaller_index && // compare_keys[dimension_to_sort] < dimension_to_sort_bound) @@ -59,15 +60,39 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, SetToFirstInsertPoint(if_data.true_block, b); auto key1 = keys_array.EmitReadArrayElement(keys_index, b); auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, b); + auto compare_key1 = key1; + auto compare_key2 = key2; auto key_type = keys_array.GetShape().element_type(); + bool 
is_signed_comparison = true; + if (primitive_util::IsFloatingPointType(key_type)) { + // We would like a total order of floating point numbers so that the sort + // has a predictable behavior in the presence of NaNs. Rather than using + // floating point comparison, we use the following trick: + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? 0x7FFFFFFF - x : x; + // then y is ordered as an int32 such that finite values have the obvious + // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning + // and end of the ordering. + auto k = b->getInt(llvm::APInt::getSignedMaxValue( + key1->getType()->getPrimitiveSizeInBits())); + auto comparison_type = k->getType(); + auto zero = llvm::ConstantInt::get(comparison_type, 0); + auto maybe_flip = [&](llvm::Value* v) { + return b->CreateSelect(b->CreateICmp(llvm::ICmpInst::ICMP_SLT, v, zero), + b->CreateSub(k, v), v); + }; + compare_key1 = b->CreateBitCast(key1, comparison_type); + compare_key2 = b->CreateBitCast(key2, comparison_type); + compare_key1 = maybe_flip(compare_key1); + compare_key2 = maybe_flip(compare_key2); + } else if (!primitive_util::IsSignedIntegralType(key_type)) { + is_signed_comparison = false; + } auto comparison = - primitive_util::IsFloatingPointType(key_type) - // TODO(b/26783907): Figure out how to handle NaNs. - ? b->CreateFCmp(llvm::FCmpInst::FCMP_ULT, key2, key1) - : b->CreateICmp(primitive_util::IsSignedIntegralType(key_type) - ? llvm::ICmpInst::ICMP_SLT - : llvm::ICmpInst::ICMP_ULT, - key2, key1); + b->CreateICmp(is_signed_comparison ? 
llvm::ICmpInst::ICMP_SLT + : llvm::ICmpInst::ICMP_ULT, + compare_key2, compare_key1); // If key2 < key1 auto if_smaller_data = EmitIfThenElse(comparison, "is_smaller_than", b, /*emit_else=*/false); @@ -87,8 +112,8 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, } // namespace Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const tensorflow::gtl::optional& values_array, - tensorflow::StringPiece name, llvm::Value* xor_mask, + const absl::optional& values_array, + absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions) { const Shape& keys_shape = keys_array.GetShape(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 8458744c6bc0e50a1c1cc8d3e66e29c7d4f74d73..527ed10374ce9482045a8459e38fd041e0e83001 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -16,12 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -31,8 +31,8 @@ namespace llvm_ir { // implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, // the inner compare loop will not be parallelized. 
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const tensorflow::gtl::optional& values_array, - tensorflow::StringPiece name, llvm::Value* xor_mask, + const absl::optional& values_array, + absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions); } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 11ed6ee59f1bf8e7004b8bef7319b37ef41a304c..7d49b8d6c2c902ee38d72f72b3da9d190cc65bf0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -64,8 +64,7 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, } } -void EmitTuple(const IrArray& tuple, - tensorflow::gtl::ArraySlice operands, +void EmitTuple(const IrArray& tuple, absl::Span operands, llvm::IRBuilder<>* b, llvm::Module* module) { for (size_t i = 0; i < operands.size(); ++i) { auto* store = b->CreateStore( diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index cf6bf5d0b14ba71cbed67f9a1dc728c0eef5e393..887fb613717ef780d6903a3b97bfdf4b735c4f82 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ +#include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" // Utilities for emitting LLVM IR related to HLO tuples. @@ -65,8 +65,7 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, // A tuple is an array of pointers, one for each operand. 
Each pointer points to // the output buffer of its corresponding operand. -void EmitTuple(const IrArray& tuple, - tensorflow::gtl::ArraySlice operands, +void EmitTuple(const IrArray& tuple, absl::Span operands, llvm::IRBuilder<>* b, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 5e02096ee501b23a7976a50f13bb7e7f3c5e2d34..0d0fb7946ae6815905491ca55652d7d0ab278a3c 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -19,10 +19,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -37,7 +39,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -73,7 +74,7 @@ namespace { // If the parameter number is invalid for this computation, nullopt is // returned. When the return value has_value(), nullptr will never be // the held value. 
-tensorflow::gtl::optional ParameterMetadata( +absl::optional ParameterMetadata( const XlaComputation& computation, int parameter_number) { for (const HloComputationProto& comp : computation.proto().computations()) { if (comp.id() == computation.proto().entry_computation_id()) { @@ -81,14 +82,14 @@ tensorflow::gtl::optional ParameterMetadata( if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) && instr.parameter_number() == parameter_number) { if (!instr.has_metadata()) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } return &instr.metadata(); } } } } - return tensorflow::gtl::nullopt; + return absl::nullopt; } ExecutionOptions CreateExecutionOptions( @@ -140,7 +141,7 @@ ExecutionOptions CreateExecutionOptions( StatusOr> LocalService::CompileExecutable( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& build_options) { const HloModuleProto& proto = computation.proto(); TF_RET_CHECK(proto.has_program_shape()); @@ -149,7 +150,7 @@ StatusOr> LocalService::CompileExecutable( // Validate incoming layouts. 
if (argument_layouts.size() != program_shape.parameters_size()) { return InvalidArgument( - "Invalid number of arguments for computation: expected %d, got %zu.", + "Invalid number of arguments for computation: expected %d, got %u.", program_shape.parameters_size(), argument_layouts.size()); } @@ -158,7 +159,7 @@ StatusOr> LocalService::CompileExecutable( TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape)); if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { - tensorflow::gtl::optional metadata = + absl::optional metadata = ParameterMetadata(computation, /*parameter_number=*/i); auto metadata_string = [&metadata]() -> string { if (!metadata.has_value()) { @@ -167,16 +168,15 @@ StatusOr> LocalService::CompileExecutable( CHECK(metadata.value() != nullptr); const OpMetadata& m = *metadata.value(); if (!m.source_file().empty()) { - return tensorflow::strings::Printf( - " (%s:%d)", m.source_file().c_str(), m.source_line()); + return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line()); } return ""; }; return InvalidArgument( "Invalid argument shape for argument %d%s, expected %s, got %s.", i, - metadata_string().c_str(), - ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(argument_shape).c_str()); + metadata_string(), + ShapeUtil::HumanString(program_shape.parameters(i)), + ShapeUtil::HumanString(argument_shape)); } } if (build_options.result_layout() != nullptr) { @@ -214,7 +214,7 @@ StatusOr LocalService::GlobalDataToShapedBuffer( TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); if (replica_number >= buffers.size()) { return InvalidArgument( - "replica_number %d out of range; must be less than num_replicas = %zu.", + "replica_number %d out of range; must be less than num_replicas = %u.", replica_number, buffers.size()); } return buffers[replica_number]; diff --git a/tensorflow/compiler/xla/service/local_service.h 
b/tensorflow/compiler/xla/service/local_service.h index 8f707ea9046a00a15cac469672a7a992f20bf483..3b4f0b50832d6d2b64528ffb63eb5c7375396aec 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -48,7 +48,7 @@ class LocalService : public Service { // compiler is responsible for freeing any memory it allocates this way. StatusOr> CompileExecutable( const XlaComputation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, + const absl::Span argument_layouts, const ExecutableBuildOptions& build_options); // Returns the device ordinal that corresponds to the given replica number. diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index c742d35a7bcafa66692195a513992c9cfbb39335..e1f56727bd209797c60f7b3f10c3e232992d01e0 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -15,11 +15,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -34,11 +34,10 @@ LogicalBuffer::~LogicalBuffer() {} string LogicalBuffer::ToString() const { string color_string; if (has_color()) { - color_string = tensorflow::strings::StrCat(" @", color().value()); + color_string = absl::StrCat(" @", color().value()); } - return tensorflow::strings::StrCat(instruction_->name(), "[", - tensorflow::str_util::Join(index_, ","), - "](#", id(), color_string, ")"); + return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","), + "](#", id(), color_string, ")"); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h index f9ba5a554740c9d4cc2643fe59d18ba76c30d03b..ceacab4ed7319527312a5a6ad715103b5bbaf40f 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.h +++ b/tensorflow/compiler/xla/service/logical_buffer.h @@ -18,13 +18,13 @@ limitations under the License. 
#include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/int_type.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index d631fb5ee42df6525681a5cd1fe1a8241824121d..eaa09591b72ee5202e0a9d1225d92eca92904adc 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -89,7 +90,7 @@ void LogicalBufferAnalysis::NewLogicalBuffer(HloInstruction* instruction, const ShapeIndex& index) { CHECK_EQ(logical_buffers_.size(), next_buffer_id_); logical_buffers_.emplace_back( - MakeUnique(instruction, index, next_buffer_id_)); + absl::make_unique(instruction, index, next_buffer_id_)); output_buffers_[std::make_pair(instruction, index)] = logical_buffers_.back().get(); diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc new file mode 100644 index 0000000000000000000000000000000000000000..8269842426e3ee15ea974098a43fe7752c7614df --- /dev/null +++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" +#include "absl/types/variant.h" +namespace xla { + +se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() { + if (HasOwnership()) { + return absl::get(mem_).AsDeviceMemoryBase(); + } else { + return absl::get(mem_); + } +} + +bool MaybeOwningDeviceMemory::HasOwnership() const { + return absl::holds_alternative(mem_); +} + +absl::optional MaybeOwningDeviceMemory::Release() { + if (!HasOwnership()) { + return {}; + } + OwningDeviceMemory result = std::move(absl::get(mem_)); + mem_ = result.AsDeviceMemoryBase(); + return absl::make_optional(std::move(result)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..82e7f1183c086437e10daea85ea99235b06cbb35 --- /dev/null +++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_ + +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/owning_device_memory.h" + +namespace xla { + +// MaybeOwningDeviceMemory represents either an owned or unowned device memory. +// Like std::variant. When the object goes +// out of scope, it will free the underlying memory if it owns it. +class MaybeOwningDeviceMemory { + public: + MaybeOwningDeviceMemory() = default; + explicit MaybeOwningDeviceMemory(OwningDeviceMemory owned) + : mem_(std::move(owned)) {} + explicit MaybeOwningDeviceMemory(se::DeviceMemoryBase unowned) + : mem_(unowned) {} + MaybeOwningDeviceMemory(MaybeOwningDeviceMemory&&) = default; + ~MaybeOwningDeviceMemory() = default; + + MaybeOwningDeviceMemory& operator=(se::DeviceMemoryBase unowned) { + mem_ = unowned; + return *this; + } + + MaybeOwningDeviceMemory& operator=(OwningDeviceMemory owned) { + mem_ = std::move(owned); + return *this; + } + + MaybeOwningDeviceMemory& operator=(MaybeOwningDeviceMemory&&) = default; + + // Fetches the underlying DeviceMemoryBase from a MaybeOwningDeviceMemory. The + // caller of this function is *not* responsible for freeing the memory.
+ se::DeviceMemoryBase AsDeviceMemoryBase(); + + // Releases the OwningDeviceMemory without freeing it, and moves the ownership + // of the memory buffer from the object to the caller. + // + // A nullopt is returned if HasOwnership() == false. + absl::optional Release(); + + // Returns true if the device_memory has ownership over underlying memory. + bool HasOwnership() const; + + private: + absl::variant mem_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_ diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 4166ef5baf9c891968b584a0c498005e9ae87784..b9ec31c4977be0c31dfff01a0c495902191d7d5b 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -262,7 +262,7 @@ void MultiOutputFusion::RecomputeReachability() { void MultiOutputFusion::UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, - tensorflow::gtl::ArraySlice instrs_to_update, + absl::Span instrs_to_update, const std::function& skip) { for (auto instr : instrs_to_update) { if (skip != nullptr && skip(instr)) { diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 0019cd725417d81900974b462c3b05075ce3e893..d2c52651c4f37708906e31b7839d0c9f6f04760e 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -19,10 +19,10 @@ limitations under the License.
#include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -48,9 +48,7 @@ class MultiOutputFusion : public HloPassInterface { public: MultiOutputFusion(int64 fuel) : fuel_(fuel) {} - tensorflow::StringPiece name() const override { - return "multi_output_fusion"; - } + absl::string_view name() const override { return "multi_output_fusion"; } // Run multi-output fusion on the given module. Returns whether the module // was changed. @@ -94,7 +92,7 @@ class MultiOutputFusion : public HloPassInterface { // Update the reachability map after fusing instr1 and instr2. void UpdateReachability( HloInstruction* instr1, HloInstruction* instr2, - tensorflow::gtl::ArraySlice instrs_to_update, + absl::Span instrs_to_update, const std::function& skip = nullptr); // Hook for multi-output fusion along producer-consumer edges. @@ -104,17 +102,17 @@ class MultiOutputFusion : public HloPassInterface { // InstructionFusion instead. virtual bool DoProducerConsumerMultiOutputFusion(); - private: - // Update the internal data structures after instr1 and instr2 are fused into - // one fusion instruction. - void Update(HloInstruction* instr1, HloInstruction* instr2); - // Optimization fuel is a compiler debugging technique that makes an // optimization pass stop what it is doing after having made N changes to the // program, where N is the fuel. By varying N, this can be used to find the // first single change that makes a test fail. int64 fuel_; + private: + // Update the internal data structures after instr1 and instr2 are fused into + // one fusion instruction. + void Update(HloInstruction* instr1, HloInstruction* instr2); + // Computation for the pass. 
HloComputation* computation_; diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index f6e7578a89551ec2f23d4d8c8b488c3c10e0bf1c..bd8fb17a235ea6eeb0e1809e8cb9ad83145fd8d6 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -15,8 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -52,8 +53,8 @@ NameUniquer::NameUniquer(const string& separator) { return result; } -string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { - string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix)); +string NameUniquer::GetUniqueName(absl::string_view prefix) { + string root = GetSanitizedName(prefix.empty() ? "name" : string(prefix)); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. @@ -63,20 +64,22 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { if (separator_index != string::npos && (separator_index > 0) && (separator_index < root.size() - 1)) { string after_suffix = root.substr(separator_index + 1); - if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) { has_numeric_suffix = true; // Remove numeric suffix from root. root = root.substr(0, separator_index); + } else { + // absl::SimpleAtoi may modify numeric_suffix even if it returns false. + numeric_suffix = 0; } } SequentialIdGenerator& id_generator = generated_names_[root]; numeric_suffix = id_generator.RegisterId(numeric_suffix); if (numeric_suffix == 0) { - return has_numeric_suffix ? 
tensorflow::strings::StrCat(root, separator_, 0) - : root; + return has_numeric_suffix ? absl::StrCat(root, separator_, 0) : root; } - tensorflow::strings::StrAppend(&root, separator_, numeric_suffix); + absl::StrAppend(&root, separator_, numeric_suffix); return root; } diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 4423d6106920eaeab830bd9dc08529ff409a5161..6dd89c240f81c9f0ccac66e50c7f244bfd5429f1 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" @@ -38,7 +38,7 @@ class NameUniquer { // Get a sanitized unique name in a string, with an optional prefix for // convenience. - string GetUniqueName(tensorflow::StringPiece prefix = ""); + string GetUniqueName(absl::string_view prefix = ""); // Sanitizes and returns the name. Unallowed characters will be replaced with // '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index ac6ea4c72f61a47726b3ae7dd000837d3fba1b93..4869db79e719fa10d61ad6c6ed41ff70a344f733 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -16,11 +16,11 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -622,7 +622,7 @@ template class HloInstructionPatternNameImpl { public: explicit HloInstructionPatternNameImpl(const Previous& previous, - tensorflow::StringPiece name) + absl::string_view name) : previous_(previous), name_(name) {} bool Match(const ::xla::HloInstruction* inst) const { @@ -631,7 +631,7 @@ class HloInstructionPatternNameImpl { private: Previous previous_; - tensorflow::StringPiece name_; + absl::string_view name_; }; // An HloInstructionPattern implementation that matches only if the instruction @@ -784,7 +784,7 @@ class HloInstructionPattern { // Modifies the pattern to match only if the instruction has the given name. HloInstructionPattern> - WithName(tensorflow::StringPiece name) const { + WithName(absl::string_view name) const { return HloInstructionPattern>( HloInstructionPatternNameImpl(impl_, name), matched_inst_); @@ -918,6 +918,7 @@ Op(::xla::HloInstruction** matched_inst) { } XLA_NULLOP_PATTERN(Constant) XLA_NULLOP_PATTERN(Parameter) +XLA_NULLOP_PATTERN(Iota) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 39fe3c7835d1c74c0f1e5bc0ebf5916ec734c24a..178a78ede09c34e71566fdee69793fdb1cda6245 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -19,20 +19,19 @@ limitations under the License. 
#include #include +#include "absl/strings/ascii.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { -using tensorflow::str_util::Lowercase; - // Minimum supported CUDA compute capability is 3.5. constexpr int kMinCudaComputeCapabilityMajor = 3; constexpr int kMinCudaComputeCapabilityMinor = 5; @@ -43,7 +42,7 @@ constexpr char kInterpreter[] = "interpreter"; namespace { string CanonicalPlatformName(const string& name) { - string platform_str = Lowercase(name); + string platform_str = absl::AsciiStrToLower(name); // "cpu" and "host" mean the same thing. if (platform_str == "cpu") { platform_str = "host"; @@ -90,41 +89,54 @@ PlatformUtil::GetSupportedPlatforms() { if (platforms.empty()) { return NotFound("no platforms found"); } else if (platforms.size() == 1) { - return platforms[0]; + se::Platform* platform = platforms[0]; + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } + return platform; } // Multiple platforms present and we can't pick a reasonable default. 
- string platforms_string = tensorflow::str_util::Join( + string platforms_string = absl::StrJoin( platforms, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform found: %s", - platforms_string.c_str()); + platforms_string); } /* static */ StatusOr PlatformUtil::GetDefaultPlatform() { TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); + + se::Platform* platform = nullptr; if (platforms.empty()) { return NotFound("no platforms found"); } else if (platforms.size() == 1) { - return platforms[0]; + platform = platforms[0]; } else if (platforms.size() == 2) { for (int i = 0; i < 2; i++) { - if (Lowercase(platforms[i]->Name()) == kInterpreter && - Lowercase(platforms[1 - i]->Name()) != kInterpreter) { - return platforms[1 - i]; + if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter && + absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) { + platform = platforms[1 - i]; + break; } } } + if (platform != nullptr) { + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } + return platform; + } // Multiple platforms present and we can't pick a reasonable default. 
- string platforms_string = tensorflow::str_util::Join( + string platforms_string = absl::StrJoin( platforms, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform (except for the " "interpreter platform) found: %s", - platforms_string.c_str()); + platforms_string); } /*static*/ StatusOr PlatformUtil::GetPlatform( @@ -132,11 +144,14 @@ PlatformUtil::GetSupportedPlatforms() { string platform_str = CanonicalPlatformName(platform_name); TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); for (se::Platform* platform : platforms) { - if (Lowercase(platform->Name()) == platform_str) { + if (absl::AsciiStrToLower(platform->Name()) == platform_str) { + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } return platform; } } - return InvalidArgument("platform %s not found", platform_name.c_str()); + return InvalidArgument("platform %s not found", platform_name); } /*static*/ StatusOr PlatformUtil::GetPlatformExceptFor( @@ -146,23 +161,27 @@ PlatformUtil::GetSupportedPlatforms() { TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); std::vector matched; for (se::Platform* platform : platforms) { - if (Lowercase(platform->Name()) != platform_name) { + if (absl::AsciiStrToLower(platform->Name()) != platform_name) { matched.push_back(platform); } } if (matched.empty()) { return InvalidArgument("unable to find platform that is not %s", - platform_name.c_str()); + platform_name); } if (matched.size() == 1) { - return matched[0]; + auto platform = matched[0]; + if (!platform->Initialized()) { + TF_RETURN_IF_ERROR(platform->Initialize({})); + } + return platform; } - string matched_string = tensorflow::str_util::Join( + string matched_string = absl::StrJoin( matched, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "found multiple platforms %s, but expected 
one platform except for %s", - matched_string.c_str(), platform_name.c_str()); + matched_string, platform_name); } // Returns whether the device underlying the given StreamExecutor is supported @@ -193,7 +212,7 @@ static bool IsDeviceSupported(se::StreamExecutor* executor) { PlatformUtil::GetStreamExecutors(se::Platform* platform) { int device_count = platform->VisibleDeviceCount(); if (device_count <= 0) { - return NotFound("no %s devices found", platform->Name().c_str()); + return NotFound("no %s devices found", platform->Name()); } if (platform->id() == se::host::kHostPlatformId) { // On host "devices", StreamExecutor exports a device for each hardware @@ -232,7 +251,7 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { if (std::all_of(stream_executors.begin(), stream_executors.end(), [](se::StreamExecutor* s) { return s == nullptr; })) { return InternalError("no supported devices found for platform %s", - platform->Name().c_str()); + platform->Name()); } return stream_executors; } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index afde3cf95c721b59a39b74b4e1ff3f15a335fe97..256b231e3af43a2ee85c97a5efab1f022d4de4b1 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -59,7 +59,7 @@ class ReducePrecisionInsertion : public HloPassInterface { ~ReducePrecisionInsertion() override{}; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "reduce-precision-insertion"; } diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index ca86c5d13e98a98c62d0c9e8e32e28fe99e0fa1f..4df746fca9f8320eed72911726f33bb01f06fed5 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -38,6 +38,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/reshape_mover.h" #include + +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -374,7 +376,7 @@ StatusOr TryReshapeMoveOnCandidates( removed = false; for (auto operand : nontrivial_operands) { - if (c_any_of(operand->users(), [&](HloInstruction* user) { + if (absl::c_any_of(operand->users(), [&](HloInstruction* user) { return !reshape_candidates->count(user); })) { for (auto* user : operand->users()) { diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h index 1f59e3b3147facb6f2fae00d6c810bf54d560e5c..1e86a0823a56a9e52421a5c8bd49e0adb98a2c70 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.h +++ b/tensorflow/compiler/xla/service/reshape_mover.h @@ -26,7 +26,7 @@ namespace xla { // them inputward also. class ReshapeMover : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "reshape-mover"; } + absl::string_view name() const override { return "reshape-mover"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index ccb9fb3e3af5e308accc924d3501213841d7d6c7..fcf269eee925c2ddb7511d70e71bd815e4b8c24a 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -15,9 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -28,13 +28,13 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" - -namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using ReshapeMoverTest = HloVerifiedTestBase; + +namespace op = xla::testing::opcode_matchers; + +class ReshapeMoverTest : public HloVerifiedTestBase {}; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { HloComputation::Builder builder(TestName()); diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 45ca731153bf4312b1c78b3c74224b1ec7ed8436..2f4b2667c405bb23b1c648892c86d337400c14a5 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/scatter_expander.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -25,7 +26,6 @@ limitations under the License. namespace xla { -using tensorflow::gtl::ArraySlice; // Transposes the given scatter_indices such that the index_vector_dim becomes // the most-minor dimension. 
@@ -86,13 +86,13 @@ static StatusOr CanonicalizeScatterIndices( // major dimensions and all the window dimensions appear in the minor // dimensions. static StatusOr PermuteScatterAndWindowDims( - HloInstruction* updates, ArraySlice update_window_dims) { + HloInstruction* updates, absl::Span update_window_dims) { std::vector permutation; const int64 updates_rank = ShapeUtil::Rank(updates->shape()); permutation.reserve(updates_rank); for (int64 i = 0; i < updates_rank; ++i) { - bool is_scatter_dim = !c_binary_search(update_window_dims, i); + bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i); if (is_scatter_dim) { permutation.push_back(i); } @@ -290,7 +290,7 @@ StatusOr ScatterExpander::ExpandScatter( return Unimplemented( "Scatter operations with more than 2147483647 scatter indices are not " "supported. This error occurred for %s.", - scatter->ToString().c_str()); + scatter->ToString()); } // Canonicalize the scatter_indices, after which the size of its most-major diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h index 8f735e877d270c10b494e1cd974904c4e2d960c9..14f062c89cfd4657097c1a933621a3e945f89c53 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.h +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -22,7 +22,7 @@ namespace xla { class ScatterExpander : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "scatter_expander"; } + absl::string_view name() const override { return "scatter_expander"; } StatusOr Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 1dbf540d13d1fb6f6a4052caeff922cc0290f1b8..f0e2566a3f9ef5c0be8af46d3a16cd9c72793366 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -20,10 +20,12 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -46,8 +48,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -55,18 +55,16 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/ptr_util.h" -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrCat; - namespace xla { - namespace { +using absl::StrCat; +using absl::StrFormat; + // Records the arguments used to invoke a computation in an HloSnapshot proto. 
-Status RecordArguments( - const tensorflow::gtl::ArraySlice arguments, - se::Stream* stream, TransferManager* transfer_manager, - HloSnapshot* module) { +Status RecordArguments(const absl::Span arguments, + se::Stream* stream, TransferManager* transfer_manager, + HloSnapshot* module) { module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( @@ -148,19 +146,19 @@ Service::Service(const ServiceOptions& options, CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas()) << "Requested more replicas than there are devices."; } - LOG(INFO) << Printf( + LOG(INFO) << StrFormat( "XLA service %p executing computations on platform %s. Devices:", this, - execute_backend_->platform()->Name().c_str()); + execute_backend_->platform()->Name()); for (int i = 0; i < execute_backend_->device_count(); ++i) { if (execute_backend_->device_ordinal_supported(i)) { se::StreamExecutor* executor = execute_backend_->stream_executor(i).ValueOrDie(); const auto& description = executor->GetDeviceDescription(); - LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i, - description.name().c_str(), - description.platform_version().c_str()); + LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i, + description.name(), + description.platform_version()); } else { - LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i); + LOG(INFO) << StrFormat(" StreamExecutor device (%d) not supported", i); } } } else { @@ -200,16 +198,16 @@ Status Service::ValidateResultShape(const Shape& client_shape, return InvalidArgument( "Shape used to set computation result layout %s is not compatible " "with result shape %s", - ShapeUtil::HumanStringWithLayout(client_shape).c_str(), - ShapeUtil::HumanString(result_shape).c_str()); + ShapeUtil::HumanStringWithLayout(client_shape), + ShapeUtil::HumanString(result_shape)); } return Status::OK(); } StatusOr>> Service::ResolveAndValidateArguments( - tensorflow::gtl::ArraySlice arguments, - 
tensorflow::gtl::ArraySlice stream_executors) { + absl::Span arguments, + absl::Span stream_executors) { CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); std::vector> replicated_arguments; replicated_arguments.resize(options_.number_of_replicas()); @@ -231,9 +229,9 @@ Service::ResolveAndValidateArguments( return InvalidArgument( "argument %lu is on device %s:%d but computation will be executed " "on device %s", - i, shaped_buffer->platform()->Name().c_str(), + i, shaped_buffer->platform()->Name(), shaped_buffer->device_ordinal(), - execute_backend_->device_name(replica_device_ordinal).c_str()); + execute_backend_->device_name(replica_device_ordinal)); } replicated_arguments[replica].push_back(shaped_buffer); } @@ -243,13 +241,13 @@ Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice argument_shapes, + absl::Span argument_shapes, const ExecutionOptions* execution_options) { - auto config = MakeUnique(program_shape); + auto config = absl::make_unique(program_shape); ComputationLayout* computation_layout = config->mutable_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { - return InvalidArgument("computation takes %d parameters, but %zu given", + return InvalidArgument("computation takes %d parameters, but %u given", program_shape.parameters_size(), argument_shapes.size()); } @@ -261,8 +259,8 @@ StatusOr> Service::CreateModuleConfig( return InvalidArgument( "Argument does not match shape of computation parameter %d: want " "%s, got %s", - i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*argument_shapes[i]).c_str()); + i, ShapeUtil::HumanString(program_shape.parameters(i)), + ShapeUtil::HumanString(*argument_shapes[i])); } TF_RETURN_IF_ERROR( computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( @@ -300,7 +298,7 @@ StatusOr> Service::CreateModuleConfig( StatusOr> 
Service::CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { @@ -314,7 +312,7 @@ StatusOr>> Service::BuildExecutables( std::vector> module_configs, Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p", this); + VLOG(1) << StrFormat("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. std::vector> hlo_snapshots; @@ -326,12 +324,11 @@ StatusOr>> Service::BuildExecutables( if (directory_path.empty() && execution_directory_path.empty()) { continue; } - auto hlo_snapshot = MakeUnique(); + auto hlo_snapshot = absl::make_unique(); *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; if (!directory_path.empty()) { - string filename = - Printf("computation_%lld__%s", module_protos[i]->id(), - module_protos[i]->entry_computation_name().c_str()); + string filename = StrFormat("computation_%d__%s", module_protos[i]->id(), + module_protos[i]->entry_computation_name()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); } @@ -369,12 +366,10 @@ StatusOr>> Service::BuildExecutables( StatusOr> Service::ExecuteParallelAndRegisterResult( - tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice>> - arguments, - Backend* backend, tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags, - ExecutionProfile* profile) { + absl::Span executables, + absl::Span>> arguments, + Backend* backend, absl::Span device_handles, + absl::Span result_tags, ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. 
std::vector streams; @@ -409,7 +404,8 @@ Service::ExecuteParallelAndRegisterResult( streams.push_back(std::move(stream)); if (replica == 0 && profile != nullptr) { - timers.push_back(MakeUnique(streams.back()->parent())); + timers.push_back( + absl::make_unique(streams.back()->parent())); streams.back() ->InitTimer(timers.back().get()) .ThenStartTimer(timers.back().get()); @@ -453,8 +449,8 @@ Service::ExecuteParallelAndRegisterResult( for (int64 i = 0; i < streams.size(); ++i) { Status block_status = streams[i]->BlockHostUntilDone(); if (!block_status.ok()) { - return InternalError("failed to complete execution for stream %lld: %s", - i, block_status.error_message().c_str()); + return InternalError("failed to complete execution for stream %d: %s", i, + block_status.error_message()); } } @@ -512,8 +508,7 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice> - arguments, + const absl::Span> arguments, Backend* backend, const string& result_tag, ExecutionProfile* profile) { // Set up streams. std::vector streams; @@ -556,8 +551,7 @@ StatusOr Service::ExecuteAndRegisterResult( // TODO(b/69985541): Support profiling also on this path. - std::vector> - replicated_arguments; + std::vector> replicated_arguments; for (const auto& arg : arguments) { replicated_arguments.push_back(arg); } @@ -579,7 +573,7 @@ StatusOr> Service::GetExecutors( if (requests_size > 1 && execution_options.device_handles_size() > 1) { return InvalidArgument( "Parallel requests with multiple device handles is not supported. 
" - "Found %lld parallel requests, with request %lld containing %d device " + "Found %d parallel requests, with request %d containing %d device " "handles.", requests_size, request_index, execution_options.device_handles_size()); } @@ -596,7 +590,7 @@ StatusOr> Service::GetExecutors( StatusOr>> Service::GetArguments( const ExecutionOptions& execution_options, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. // In the case of partitioned computations, assume all arguments go on the @@ -744,8 +738,8 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, } if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( - "Requested device count (%lld) exceeds the number of available devices " - "on the target (%lld)", + "Requested device count (%d) exceeds the number of available devices " + "on the target (%d)", arg->device_count(), available_device_count); } @@ -795,12 +789,12 @@ StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf( + VLOG(1) << StrFormat( "BuildExecutable on service %p with serialized module proto: %s", this, - module_proto.name().c_str()); + module_proto.name()); // Dump computation proto state if flag is set. 
- auto hlo_snapshot = MakeUnique(); + auto hlo_snapshot = absl::make_unique(); const string& directory_path = module_config->debug_options().xla_dump_computations_to(); const string& execution_directory_path = @@ -808,8 +802,8 @@ StatusOr> Service::BuildExecutable( if (!directory_path.empty() || !execution_directory_path.empty()) { *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s", module_proto.id(), - module_proto.entry_computation_name().c_str()); + string filename = StrFormat("computation_%d__%s", module_proto.id(), + module_proto.entry_computation_name()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); } @@ -954,7 +948,7 @@ namespace { // shape and DeviceMemoryBase values of the clone are identical to the original. std::unique_ptr CloneShapedBufferOnDevice( const ShapedBuffer& shaped_buffer, int device_ordinal) { - auto clone = MakeUnique( + auto clone = absl::make_unique( shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), shaped_buffer.platform(), device_ordinal); clone->buffers() = shaped_buffer.buffers(); @@ -1009,8 +1003,7 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, "%s", StrCat("The replica_id=", arg->replica_id(), " on TransferToInfeedRequest not in range [0, replica_count=", - replica_count, ").") - .c_str()); + replica_count, ").")); } se::StreamExecutor* executor; @@ -1036,8 +1029,7 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( - "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, " - "%lld)", + "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)", arg->replica_id(), replica_count); } diff --git a/tensorflow/compiler/xla/service/service.h 
b/tensorflow/compiler/xla/service/service.h index 47d196fb2aaee897ce1fd3745129af10bf5b2d2d..44c5248b150cff57546d3287869787f37c8975ba 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/allocation_tracker.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -176,7 +176,7 @@ class Service : public ServiceInterface { // class. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutionOptions& execution_options); // Picks a parallel response and fills the result. @@ -191,7 +191,7 @@ class Service : public ServiceInterface { // Prepare the arguments for executing parallel. StatusOr>> GetArguments( const ExecutionOptions& execution_options, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); protected: friend class LocalExecutable; @@ -207,14 +207,14 @@ class Service : public ServiceInterface { // the corresponding replica. StatusOr>> ResolveAndValidateArguments( - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice stream_executors); + absl::Span arguments, + absl::Span stream_executors); // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. 
StatusOr> CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice argument_shapes, + absl::Span argument_shapes, const ExecutionOptions* execution_options); // Builds an Executable for the given parameters. @@ -242,21 +242,17 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. StatusOr ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice> - arguments, + const absl::Span> arguments, Backend* backend, const string& result_tag, ExecutionProfile* profile); // Runs the given executables with the given arguments and register the result // from each executable in the allocation tracker. The handles of the result // from the tracker are returned. StatusOr> ExecuteParallelAndRegisterResult( - tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice>> - arguments, - Backend* backend, - tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags, - ExecutionProfile* profile); + absl::Span executables, + absl::Span>> arguments, + Backend* backend, absl::Span device_handles, + absl::Span result_tags, ExecutionProfile* profile); // Executes a single computation which has more than one target device. // The N devices are expected to all return an empty tuple, but one, which diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index a4ea2b28f4dbf41d61702f1af2d65c4d2c86d578..26117498621450d56259507761b6b0a6ea8d3a15 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -21,6 +21,11 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -28,44 +33,37 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" -using tensorflow::str_util::Join; -using tensorflow::strings::Printf; - namespace xla { - namespace { +using absl::StrFormat; +using absl::StrJoin; + // Returns true if no element is present in slice more than once. 
-bool AllUnique(tensorflow::gtl::ArraySlice slice) { +bool AllUnique(absl::Span slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { +Status ExpectArray(const Shape& shape, absl::string_view op_type) { if (!ShapeUtil::IsArray(shape)) { return InvalidArgument("Expected array argument for %s, but got %s.", - std::string(op_type).c_str(), - ShapeUtil::HumanString(shape).c_str()); + string(op_type), ShapeUtil::HumanString(shape)); } return Status::OK(); } -Status VerifyReducerShape( - const ProgramShape& reducer_shape, - tensorflow::gtl::ArraySlice init_value_shapes, - tensorflow::gtl::ArraySlice input_element_types, - int64 inputs) { +Status VerifyReducerShape(const ProgramShape& reducer_shape, + absl::Span init_value_shapes, + absl::Span input_element_types, + int64 inputs) { if (reducer_shape.parameters_size() != inputs * 2) { return InvalidArgument( - "Reduction function must take %lld parameters, but " + "Reduction function must take %d parameters, but " "takes %d parameter(s).", inputs * 2, reducer_shape.parameters_size()); } @@ -75,7 +73,7 @@ Status VerifyReducerShape( if (ShapeUtil::IsArray(accumulator_shape)) { if (inputs != 1) { return InvalidArgument( - "Reduction function must produce a tuple with %lld elements, but " + "Reduction function must produce a tuple with %d elements, but " "produces a scalar", inputs); } @@ -83,8 +81,8 @@ Status VerifyReducerShape( } else if (ShapeUtil::IsTuple(accumulator_shape)) { if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { return InvalidArgument( - "Reduction function must produce a tuple with %lld elements, but has " - "%lld elements", + "Reduction function must produce a tuple with %d elements, but has " + "%d elements", inputs, ShapeUtil::TupleElementCount(accumulator_shape)); } for (const Shape& element_shape : accumulator_shape.tuple_shapes()) { @@ -94,7 +92,7 @@ Status VerifyReducerShape( return 
InvalidArgument( "Reduction function must produce a scalar or tuple of scalars, but has " "shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); + ShapeUtil::HumanString(accumulator_shape)); } for (const Shape* element_shape : accumulator_subshapes) { @@ -102,7 +100,7 @@ Status VerifyReducerShape( return InvalidArgument( "Reduction function must return a scalar or tuple of scalars but " "returns shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); + ShapeUtil::HumanString(accumulator_shape)); } } @@ -113,19 +111,19 @@ Status VerifyReducerShape( if (!ShapeUtil::Compatible(*accumulator_subshapes[i], reducer_shape.parameters(i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape differs from the " + "Reduction function's %d-th parameter shape differs from the " "result shape: %s vs %s", - i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + i, ShapeUtil::HumanString(reducer_shape.parameters(i)), + ShapeUtil::HumanString(*accumulator_subshapes[i])); } // Check that init_value's shapes are suitable for reducer_shape. if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i], *init_value_shapes[i])) { return InvalidArgument( - "Reduction function's accumulator shape at index %lld differs from " + "Reduction function's accumulator shape at index %d differs from " "the init_value shape: %s vs %s", - i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(), - ShapeUtil::HumanString(*init_value_shapes[i]).c_str()); + i, ShapeUtil::HumanString(*accumulator_subshapes[i]), + ShapeUtil::HumanString(*init_value_shapes[i])); } // Check that the inputs can be passed in as the non-accumulator arguments. 
const Shape input_element_shape = @@ -133,11 +131,11 @@ Status VerifyReducerShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision( input_element_shape, reducer_shape.parameters(inputs + i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape differs from the " + "Reduction function's %d-th parameter shape differs from the " "input type element type: %s vs %s", inputs + i, - ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), - ShapeUtil::HumanString(input_element_shape).c_str()); + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)), + ShapeUtil::HumanString(input_element_shape)); } // Check that the accumulator and inputs to the reducer function match. // If the accumulator is scalar, it must have the same type as the inputs @@ -147,11 +145,11 @@ Status VerifyReducerShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision( *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape must " + "Reduction function's %d-th parameter shape must " "match the result shape, but got %s vs %s.", inputs + i, - ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), - ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)), + ShapeUtil::HumanString(*accumulator_subshapes[i])); } } @@ -164,7 +162,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, bool allow_negative_padding) { if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) { return InvalidArgument( - "Window has dimension %d but base shape has dimension %lld.", + "Window has dimension %d but base shape has dimension %d.", window.dimensions_size(), ShapeUtil::Rank(base_shape)); } @@ -173,29 +171,29 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const auto& dim = window.dimensions(i); if (dim.size() <= 0) { return InvalidArgument("Window %s has a non-positive dimension.", - 
window.DebugString().c_str()); + window.DebugString()); } if (dim.stride() <= 0) { return InvalidArgument("Window %s has a non-positive stride.", - window.DebugString().c_str()); + window.DebugString()); } if (!allow_negative_padding && dim.padding_low() < 0) { return InvalidArgument("Window %s has a negative low padding.", - window.DebugString().c_str()); + window.DebugString()); } if (!allow_negative_padding && dim.padding_high() < 0) { return InvalidArgument("Window %s has a negative high padding.", - window.DebugString().c_str()); + window.DebugString()); } if (dim.base_dilation() < 1) { return InvalidArgument( "Window %s has a non-positive base area dilation factor.", - window.DebugString().c_str()); + window.DebugString()); } if (dim.window_dilation() < 1) { return InvalidArgument( "Window %s has a non-positive window dilation factor.", - window.DebugString().c_str()); + window.DebugString()); } const int64 dilated_base = window_util::DilatedBound( @@ -233,11 +231,12 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, switch (opcode) { case HloOpcode::kFloor: case HloOpcode::kCeil: + case HloOpcode::kRoundNearestAfz: if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( - "Expected element type in shape to be floating for floor/ceil " - "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + "Expected element type in shape to be floating for %s operation; " + "got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; case HloOpcode::kCos: @@ -250,9 +249,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (!ShapeUtil::ElementIsFloating(shape) && !ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( - "Expected element type in shape to be floating or complex for " - "sin/cos/exp/log/tanh operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + "Expected element type in shape to be floating or complex for %s " + "operation; got %s.", + 
HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; case HloOpcode::kReal: @@ -264,19 +263,47 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } else { return InvalidArgument( "Expected element type in shape to be floating or complex for " - "real/imag operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + "%s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } case HloOpcode::kAbs: if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( shape, primitive_util::ComplexComponentType(shape.element_type())); + } else if (ShapeUtil::ElementIsSigned(shape)) { + return shape; + } else { + return InvalidArgument( + "Expected element type in shape to be floating or complex for " + "%s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } - return shape; case HloOpcode::kClz: + if (!ShapeUtil::ElementIsIntegral(shape)) { + return InvalidArgument( + "Expected an integral element type in argument to Clz " + "operation; got %s.", + PrimitiveType_Name(shape.element_type())); + } + return shape; case HloOpcode::kNegate: - case HloOpcode::kRoundNearestAfz: + if (!ShapeUtil::ElementIsIntegral(shape) && + !ShapeUtil::ElementIsFloating(shape) && + !ShapeUtil::ElementIsComplex(shape)) { + return InvalidArgument( + "Expected element type in shape to be integral, floating or " + "complex for %s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); + } + return shape; case HloOpcode::kSign: + if (!ShapeUtil::ElementIsSigned(shape) && + !ShapeUtil::ElementIsComplex(shape)) { + return InvalidArgument( + "Expected element type in shape to be signed or complex for " + "%s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); + } return shape; case HloOpcode::kNot: @@ -285,7 +312,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return 
InvalidArgument( "Expected pred or an integral element type in argument to Not " "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return shape; @@ -295,25 +322,24 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, "Expected element type in shape to be floating " "point for IsFinite " "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return ShapeUtil::ChangeElementType(shape, PRED); default: return InvalidArgument( "Unknown operation for unary shape inference: \"%s\".", - HloOpcodeString(opcode).c_str()); + HloOpcodeString(opcode)); } } /* static */ StatusOr ShapeInference::InferConcatOpShape( - tensorflow::gtl::ArraySlice arg_shapes, - const int64 dimension) { + absl::Span arg_shapes, const int64 dimension) { if (arg_shapes.empty()) { return InvalidArgument("Concatenate expects at least one argument."); } if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { - return InvalidArgument("Concatenate dimension out of bounds: %lld.", + return InvalidArgument("Concatenate dimension out of bounds: %d.", dimension); } const Shape* arg_shape = nullptr; @@ -327,17 +353,16 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { return InvalidArgument( - "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld " + "Cannot concatenate arrays with different ranks: %d (%s) vs %d " "(%s).", - ShapeUtil::Rank(*arg_shape), - ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), - ShapeUtil::HumanString(*shape).c_str()); + ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape), + ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( "Cannot concatenate arrays with different element types: %s vs %s.", - 
PrimitiveType_Name(arg_shape->element_type()).c_str(), - PrimitiveType_Name(shape->element_type()).c_str()); + PrimitiveType_Name(arg_shape->element_type()), + PrimitiveType_Name(shape->element_type())); } for (int64 dimension_number = 0; dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) { @@ -350,9 +375,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Cannot concatenate arrays that differ in dimensions other than " "the one being concatenated (the other array dimensions must be " - "the same): %s vs %s in dimension %lld.", - ShapeUtil::HumanString(*arg_shape).c_str(), - ShapeUtil::HumanString(*shape).c_str(), dimension); + "the same): %s vs %s in dimension %d.", + ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape), + dimension); } } element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape); @@ -367,7 +392,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } /* static */ StatusOr ShapeInference::InferAfterAllShape( - tensorflow::gtl::ArraySlice arg_shapes) { + absl::Span arg_shapes) { for (const Shape* arg_shape : arg_shapes) { if (arg_shape->element_type() != TOKEN) { return InvalidArgument( @@ -384,8 +409,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, !primitive_util::IsComplexType(new_element_type)) { return Unimplemented( "Conversion from complex to real type %s => %s is not implemented.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (!ShapeUtil::IsArray(operand_shape) || !primitive_util::IsArrayType(new_element_type)) { @@ -394,8 +419,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // are valid. For now we just reject them, though. 
return InvalidArgument( "Convert does not allow non-arrays, so cannot convert from %s to %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } return ShapeUtil::ChangeElementType(operand_shape, new_element_type); @@ -407,8 +432,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (primitive_util::IsComplexType(old_element_type) != primitive_util::IsComplexType(new_element_type)) { return InvalidArgument("Conversion from complex to real type %s => %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (!ShapeUtil::IsArray(operand_shape) || !primitive_util::IsArrayType(new_element_type)) { @@ -417,15 +442,15 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // are valid. For now we just reject them, though. 
return InvalidArgument( "Cannot convert from or to tuple type; requested conversion: %s => %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (primitive_util::BitWidth(old_element_type) != primitive_util::BitWidth(new_element_type)) { return InvalidArgument( "Cannot bitcast types with different bit-widths: %s => %s.", - PrimitiveType_Name(old_element_type).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + PrimitiveType_Name(old_element_type), + PrimitiveType_Name(new_element_type)); } return ShapeUtil::ChangeElementType(operand_shape, new_element_type); @@ -438,7 +463,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected element type in shape to be floating point for " "ReducePrecision operation; got %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (exponent_bits < 1) { // One exponent bit is necessary to distinguish 0 from infinity. 
Having @@ -470,21 +495,29 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "The rank of the operand and the padding configuration do not match: " "%s vs %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - padding_config.ShortDebugString().c_str()); + ShapeUtil::HumanString(operand_shape), + padding_config.ShortDebugString()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, padding_value_shape)) { return InvalidArgument( "The element types of the operands to Pad do not match."); } + if (absl::c_any_of(padding_config.dimensions(), + [](const PaddingConfig::PaddingConfigDimension& p) { + return p.interior_padding() < 0; + })) { + return InvalidArgument("Interior padding cannot be negative: %s", + padding_config.ShortDebugString()); + } + std::vector dimensions(ShapeUtil::Rank(operand_shape)); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { - dimensions[i] = operand_shape.dimensions(i) + - padding_config.dimensions(i).edge_padding_low() + - padding_config.dimensions(i).edge_padding_high() + + const auto& p = padding_config.dimensions(i); + dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() + + p.edge_padding_high() + std::max(operand_shape.dimensions(i) - 1, 0LL) * - padding_config.dimensions(i).interior_padding(); + p.interior_padding(); } return ShapeUtil::MakeShape( ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), @@ -515,22 +548,22 @@ Status ValidateDotDimensionNumbers( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers) { // Check that dimension numbers are in range. 
- auto dims_in_range = - [](const int64 rank, tensorflow::gtl::ArraySlice contracting_dims, - tensorflow::gtl::ArraySlice batch_dims) -> bool { + auto dims_in_range = [](const int64 rank, + absl::Span contracting_dims, + absl::Span batch_dims) -> bool { auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; return std::all_of(contracting_dims.begin(), contracting_dims.end(), in_range) && std::all_of(batch_dims.begin(), batch_dims.end(), in_range); }; - tensorflow::gtl::ArraySlice lhs_contracting_dimensions = + absl::Span lhs_contracting_dimensions = AsInt64Slice(dimension_numbers.lhs_contracting_dimensions()); - tensorflow::gtl::ArraySlice rhs_contracting_dimensions = + absl::Span rhs_contracting_dimensions = AsInt64Slice(dimension_numbers.rhs_contracting_dimensions()); - tensorflow::gtl::ArraySlice lhs_batch_dimensions = + absl::Span lhs_batch_dimensions = AsInt64Slice(dimension_numbers.lhs_batch_dimensions()); - tensorflow::gtl::ArraySlice rhs_batch_dimensions = + absl::Span rhs_batch_dimensions = AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, @@ -538,12 +571,12 @@ Status ValidateDotDimensionNumbers( !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is out of range in Dot: %s.", - dimension_numbers.DebugString().c_str()); + dimension_numbers.DebugString()); } // Check that dimension numbers are unique. 
- auto dims_unique = [](tensorflow::gtl::ArraySlice contracting_dims, - tensorflow::gtl::ArraySlice batch_dims) -> bool { + auto dims_unique = [](absl::Span contracting_dims, + absl::Span batch_dims) -> bool { tensorflow::gtl::FlatSet dim_set; auto is_unique = [&dim_set](int64 i) -> bool { return dim_set.insert(i).second; @@ -556,7 +589,7 @@ Status ValidateDotDimensionNumbers( if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is not unique in Dot: %s.", - dimension_numbers.DebugString().c_str()); + dimension_numbers.DebugString()); } // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. @@ -601,14 +634,13 @@ Status ValidateDotDimensionNumbers( TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); auto fail = [lhs, rhs](const string& addendum) -> Status { - string message = tensorflow::strings::Printf( - "Cannot infer shape for dot operation: %s %s.", - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + string message = + StrFormat("Cannot infer shape for dot operation: %s %s.", + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs)); if (!addendum.empty()) { message += " " + addendum; } - return InvalidArgument("%s", message.c_str()); + return InvalidArgument("%s", message); }; // Check if both element types are the same. 
@@ -704,9 +736,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", - HloOpcodeString(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + HloOpcodeString(operation), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs)); } } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), @@ -715,20 +746,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { // Reject "magic" inference for binops on different shapes, requiring // the user to provide an explicit broadcast dimension in this case. // See b/25177275 for more details. 
return InvalidArgument("Automatic shape inference not supported: %s and %s", - ShapeUtil::HumanString(smaller_shape).c_str(), - ShapeUtil::HumanString(larger_shape).c_str()); + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { return InvalidArgument( "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " - " lower-rank operand's rank is %lld, size of broadcast_dimensions is " - "%zu.", + " lower-rank operand's rank is %d, size of broadcast_dimensions is " + "%u.", ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); } @@ -778,12 +809,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, int64 dimension_to_match = broadcast_dimensions.at(i); if (dimension_to_match < 0) { return InvalidArgument( - "Broadcast dimension number (%lld) cannot be negative.", + "Broadcast dimension number (%d) cannot be negative.", dimension_to_match); } if (dimension_to_match >= larger_shape.dimensions_size()) { return InvalidArgument( - "Broadcast dimension number (%lld) too large; higher-rank " + "Broadcast dimension number (%d) too large; higher-rank " "operand has rank %d.", dimension_to_match, larger_shape.dimensions_size()); } @@ -795,16 +826,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (small_dimension_size != large_dimension_size && small_dimension_size != 1 && large_dimension_size != 1) { return InvalidArgument( - "Broadcast dimension %d mismatch: %lld != %lld; %s and %s.", i, + "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i, small_dimension_size, large_dimension_size, - ShapeUtil::HumanString(smaller_shape).c_str(), - ShapeUtil::HumanString(larger_shape).c_str()); + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); } // Make sure the broadcast dimensions are listed in a strictly increasing // order. 
if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { return InvalidArgument( - "Broadcast dimensions order is wrong: %lld comes after %lld.", + "Broadcast dimensions order is wrong: %d comes after %d.", dimension_to_match, broadcast_dimensions.at(i - 1)); } @@ -816,15 +847,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { + absl::Span broadcast_dimensions) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Binary op %s with different element types: %s and %s.", - HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + HloOpcodeString(operation), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs)); } if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { @@ -873,21 +904,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - VLOG(2) << tensorflow::strings::Printf( + absl::Span broadcast_dimensions) { + VLOG(2) << StrFormat( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", - HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str(), - Join(broadcast_dimensions, ", ").c_str()); + HloOpcodeString(opcode), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", ")); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - 
TF_RETURN_IF_ERROR( - ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ", - HloOpcodeString(opcode)))); - TF_RETURN_IF_ERROR( - ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ", - HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR(ExpectArray( + lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR(ExpectArray( + rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode)))); switch (opcode) { case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -909,7 +937,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected element type in shape to be floating for complex compose " "operation; got %s.", - PrimitiveType_Name(lhs.element_type()).c_str()); + PrimitiveType_Name(lhs.element_type())); } TF_ASSIGN_OR_RETURN(const Shape& shape, InferElementwiseBinaryOpShape(opcode, lhs, rhs, @@ -928,7 +956,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected pred or integral type in argument to and/or operation; " "got %s.", - PrimitiveType_Name(lhs.element_type()).c_str()); + PrimitiveType_Name(lhs.element_type())); } return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); @@ -946,8 +974,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, default: return Unimplemented( "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", - HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(), - rhs.ShortDebugString().c_str()); + HloOpcodeString(opcode), lhs.ShortDebugString(), + rhs.ShortDebugString()); } } @@ -970,14 +998,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kTupleSelect: return InferTupleSelectShape(lhs, rhs, ehs); default: - return InvalidArgument("Unknown operation %s.", - HloOpcodeString(opcode).c_str()); + return InvalidArgument("Unknown 
operation %s.", HloOpcodeString(opcode)); } } /* static */ StatusOr ShapeInference::InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operands) { + HloOpcode opcode, absl::Span operands) { std::vector operand_shapes; operand_shapes.reserve(operands.size()); for (const HloInstruction* operand : operands) { @@ -987,8 +1013,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operand_shapes) { + HloOpcode opcode, absl::Span operand_shapes) { for (const Shape* shape : operand_shapes) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); } @@ -1010,8 +1035,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Sort keys and values dimensions must match. " "Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(*operand_shapes[0]).c_str(), - ShapeUtil::HumanString(*operand_shapes[1]).c_str()); + ShapeUtil::HumanString(*operand_shapes[0]), + ShapeUtil::HumanString(*operand_shapes[1])); } return ShapeUtil::MakeTupleShape( {*operand_shapes[0], *operand_shapes[1]}); @@ -1019,15 +1044,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("Unexpected number of operands for sort"); } default: - return InvalidArgument("Unknown operation %s.", - HloOpcodeString(opcode).c_str()); + return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode)); } } /* static */ StatusOr ShapeInference::InferMapShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply, - tensorflow::gtl::ArraySlice dimensions) { + absl::Span arg_shapes, const ProgramShape& to_apply, + absl::Span dimensions) { if (arg_shapes.empty()) { return InvalidArgument("Map expects at least one argument."); } @@ -1058,7 +1081,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, 
return InvalidArgument( "Map operation requires all operands to have the same shape; got: " "%s.", - Join(pieces, ", ").c_str()); + StrJoin(pieces, ", ")); } // Check that dimensions.size == arg_shape.dimensions_size() (we currently @@ -1066,7 +1089,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (dimensions.size() != arg_shape->dimensions_size()) { return InvalidArgument( "Map applied to a subset of dimensions currently not supported: " - "arg_dimension_size: %d, requested_map_dimensions_size: %zu.", + "arg_dimension_size: %d, requested_map_dimensions_size: %u.", arg_shape->dimensions_size(), dimensions.size()); } @@ -1075,7 +1098,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (dimensions[i] != i) { return InvalidArgument( "Map requires monotonically increasing dimension numbers; got: %s.", - Join(dimensions, ", ").c_str()); + StrJoin(dimensions, ", ")); } } @@ -1083,7 +1106,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (arg_shapes.size() != to_apply.parameters_size()) { return InvalidArgument( "Map applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu.", + "arity: %d, arguments: %u.", to_apply.parameters_size(), arg_shapes.size()); } @@ -1092,7 +1115,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::IsScalar(output_shape)) { return InvalidArgument( "Mapped computation's result has to be a scalar; got: %s.", - ShapeUtil::HumanString(output_shape).c_str()); + ShapeUtil::HumanString(output_shape)); } for (int i = 0; i < to_apply.parameters_size(); ++i) { @@ -1102,7 +1125,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Mapped computation's parameter has to be a scalar; " "got parameter %d shape: %s.", - i, ShapeUtil::HumanString(parameter_shape).c_str()); + i, ShapeUtil::HumanString(parameter_shape)); } if 
(!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape, @@ -1110,8 +1133,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Mapped computation's parameter type has to match argument element " "type; got parameter %d shape: %s, argument shape: %s.", - i, ShapeUtil::HumanString(parameter_shape).c_str(), - ShapeUtil::HumanString(*arg_shape).c_str()); + i, ShapeUtil::HumanString(parameter_shape), + ShapeUtil::HumanString(*arg_shape)); } } @@ -1140,35 +1163,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-training to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-training to " - "be a non-negative number, got %lld.", + "be a non-negative number, got %d.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-training to be at least 1; got %lld.", + "batch-norm-training to be at least 1; got %d.", ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(offset_shape) != 1) { return InvalidArgument( "Offset input of batch-norm-training must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(offset_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-training must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } @@ -1176,7 +1199,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-training must have a floating point " "element type, but the shape is %s.", - 
PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, @@ -1185,8 +1208,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-training, " "but the shape of offset factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(offset_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(offset_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1195,8 +1218,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-training, " "but the shape of scale factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1206,16 +1229,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { return InvalidArgument( "The size of offset factor should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and 
the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1250,35 +1273,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-inference to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-inference to " - "be a non-negative number, got %lld.", + "be a non-negative number, got %d.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-inference to be at least 1; got %lld.", + "batch-norm-inference to be at least 1; got %d.", ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(offset_shape) != 1) { return InvalidArgument( "Offset input of batch-norm-inference must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(offset_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-inference must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } @@ -1286,7 +1309,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-inference must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, @@ -1296,8 +1319,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of offset factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(offset_shape.element_type()).c_str(), - 
PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(offset_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1307,8 +1330,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of scale factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, @@ -1318,8 +1341,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of mean is %s " "and the shape of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape, @@ -1329,8 +1352,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of variance is %s " "and the shape of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(variance_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(variance_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1340,32 +1363,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { return InvalidArgument( "The size of offset factor should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count 
is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { return InvalidArgument( "The size of mean should be the same as feature count," - "but the size of mean is %lld " - "and the feature count is %lld.", + "but the size of mean is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) { return InvalidArgument( "The size of variance should be the same as feature count," - "but the size of variance is %lld " - "and the feature count is %lld.", + "but the size of variance is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(variance_shape, 0), feature_count); } @@ -1395,36 +1418,36 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-grad to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) { return InvalidArgument( "Expected operand_shape of batch-norm-grad to have the same rank as" - " output_grad_shape; got rank(oprand_shape) %lld, and" - " rank(output_grad_shape) %lld.", + " output_grad_shape; got rank(oprand_shape) %d, and" + " rank(output_grad_shape) %d.", ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); } if 
(ShapeUtil::Rank(mean_shape) != 1) { return InvalidArgument( "Mean input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(mean_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } if (ShapeUtil::Rank(var_shape) != 1) { return InvalidArgument( "Var input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(var_shape)); } @@ -1432,14 +1455,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-grad must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::ElementIsFloating(output_grad_shape)) { return InvalidArgument( "The output_grad to batch-norm-grad must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(output_grad_shape.element_type()).c_str()); + PrimitiveType_Name(output_grad_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape, @@ -1448,8 +1471,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of output_grad is %s " "and the element type of operand is %s.", - PrimitiveType_Name(output_grad_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(output_grad_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1458,8 +1481,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the 
same element type for batch-norm-grad, " "but the element type of scale factor is %s " "and the element type of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, @@ -1468,8 +1491,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " "and the element type of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape, @@ -1478,8 +1501,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " "and the element type of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1490,24 +1513,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { return InvalidArgument( "The size of mean should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != 
feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) { return InvalidArgument( "The size of variance should be the same as feature count," - "but the size of variance is %lld " - "and the feature count is %lld.", + "but the size of variance is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(var_shape, 0), feature_count); } @@ -1517,8 +1540,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::GetDimension(output_grad_shape, i)) { return InvalidArgument( "The bounds of operand shape should be the same as output_grad's," - "but the bound of operand_shape at dimension %lld is %lld " - "and the bound of output_grad_shape is %lld.", + "but the bound of operand_shape at dimension %d is %d " + "and the bound of output_grad_shape is %d.", i, ShapeUtil::GetDimension(operand_shape, i), ShapeUtil::GetDimension(output_grad_shape, i)); } @@ -1530,22 +1553,21 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dnums) { + const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s.", - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs)); } if (dnums.input_spatial_dimensions_size() != 
dnums.kernel_spatial_dimensions_size()) { return InvalidArgument( "Both arguments to convolution must have same number of dimensions.\n" "Window: %s", - window.DebugString().c_str()); + window.DebugString()); } const int num_spatial_dims = dnums.input_spatial_dimensions_size(); @@ -1553,19 +1575,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" "Window: %s\nDimension numbers: %s.", - window.DebugString().c_str(), dnums.DebugString().c_str()); + window.DebugString(), dnums.DebugString()); } const int num_dims = num_spatial_dims + 2; if (ShapeUtil::Rank(lhs) != num_dims) { return InvalidArgument( "The LHS argument to a convolution should have rank %d; lhs: %s.", - num_dims, ShapeUtil::HumanString(lhs).c_str()); + num_dims, ShapeUtil::HumanString(lhs)); } if (ShapeUtil::Rank(rhs) != num_dims) { return InvalidArgument( "The RHS argument to a convolution should have rank %d; lhs: %s.", - num_dims, ShapeUtil::HumanString(lhs).c_str()); + num_dims, ShapeUtil::HumanString(lhs)); } TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); @@ -1602,26 +1624,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { return InvalidArgument( "A dimension number is out of range in convolution: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (input_dnums != expected_dnums) { return InvalidArgument( "Input dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (window_dnums != expected_dnums) { return InvalidArgument( "Window dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (output_dnums != expected_dnums) { return 
InvalidArgument( "Output dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } std::vector input_spatial_dims(num_spatial_dims); @@ -1640,14 +1662,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 kernel_output_features = rhs.dimensions(dnums.kernel_output_feature_dimension()); - if (input_features != kernel_input_features) { + if (input_features != kernel_input_features * feature_group_count) { return InvalidArgument( - "Expected LHS feature dimension (value %lld) to match RHS " - "input feature dimension (value %lld); got (%s, %s)\n" + "Expected LHS feature dimension (value %d) to match RHS " + "input feature dimension * feature_group_count (value %d); " + "got (%s, %s)\n" "Dimension numbers: {%s}.", - input_features, kernel_input_features, - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str()); + input_features, kernel_input_features * feature_group_count, + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); } std::vector window_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -1659,8 +1682,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "RHS shape: %s\n\t" "Window: {%s}\n\t" "Dimension numbers: {%s}.", - ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(), - dnums.ShortDebugString().c_str()); + ShapeUtil::HumanString(rhs), window.ShortDebugString(), + dnums.ShortDebugString()); } Shape base_shape = @@ -1683,32 +1706,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferFftShape( const Shape& in, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { + const absl::Span fft_length) { const int64 fft_rank = fft_length.size(); if (fft_rank < 1 || fft_rank > 3) { - return InvalidArgument("FFT only 
supports ranks 1-3; got %lld.", fft_rank); + return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank); } -#define RET_CHECK_RANK(x) \ - if (x.dimensions_size() < fft_rank) { \ - return InvalidArgument( \ - "FFT of rank %lld requires input of at least " \ - "same rank; got input of rank %d", \ - fft_rank, x.dimensions_size()); \ +#define RET_CHECK_RANK(x) \ + if (x.dimensions_size() < fft_rank) { \ + return InvalidArgument( \ + "FFT of rank %d requires input of at least " \ + "same rank; got input of rank %d", \ + fft_rank, x.dimensions_size()); \ } switch (fft_type) { case FFT: case IFFT: if (in.element_type() != C64) { return InvalidArgument("%s requires C64 input type, found %s.", - FftType_Name(fft_type).c_str(), - PrimitiveType_Name(in.element_type()).c_str()); + FftType_Name(fft_type), + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); return in; case RFFT: { if (in.element_type() != F32) { return InvalidArgument("RFFT requires F32 input type, found %s.", - PrimitiveType_Name(in.element_type()).c_str()); + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); for (int i = 0; i < fft_rank; i++) { @@ -1716,7 +1739,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]) { return InvalidArgument( "RFFT requires innermost dimensions match fft_length but " - "dimension %lld is %lld and should be %lld.", + "dimension %d is %d and should be %d.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1730,7 +1753,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case IRFFT: { if (in.element_type() != C64) { return InvalidArgument("IRFFT requires C64 input type, found %s.", - PrimitiveType_Name(in.element_type()).c_str()); + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); Shape result = ShapeUtil::ComplexComponentShape(in); @@ -1739,7 +1762,7 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]) { return InvalidArgument( "IRFFT requires all but one innermost dimensions match " - "fft_length, but dimension %lld is %lld and should be %lld.", + "fft_length, but dimension %d is %d and should be %d.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1749,7 +1772,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[fft_rank - 1] / 2 + 1) { return InvalidArgument( "IRFFT requires innermost dimension matches fft_length/2+1, but " - "dimension %d is %lld and should be %lld.", + "dimension %d is %d and should be %d.", in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1), fft_length[fft_rank - 1] / 2 + 1); } @@ -1764,7 +1787,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferCrossReplicaSumShape( - tensorflow::gtl::ArraySlice operand_shapes) { + absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( ExpectArray(*operand_shape, "operand of cross replica sum")); @@ -1785,18 +1808,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RET_CHECK(split_count > 0); if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { return InvalidArgument( - "AllToAll split_dimension %lld is out-of-bounds in shape %s.", - split_dimension, ShapeUtil::HumanString(shape).c_str()); + "AllToAll split_dimension %d is out-of-bounds in shape %s.", + split_dimension, ShapeUtil::HumanString(shape)); } if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { return InvalidArgument( - "AllToAll concat_dimension %lld is out-of-bounds in shape %s.", - concat_dimension, ShapeUtil::HumanString(shape).c_str()); + "AllToAll concat_dimension %d is out-of-bounds in shape %s.", + concat_dimension, 
ShapeUtil::HumanString(shape)); } if (shape.dimensions(split_dimension) % split_count != 0) { return InvalidArgument( - "AllToAll split dimension size %lld must be dividable by split_count " - "%lld.", + "AllToAll split dimension size %d must be dividable by split_count " + "%d.", shape.dimensions(split_dimension), split_count); } std::vector new_dimensions(shape.dimensions().begin(), @@ -1807,7 +1830,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferAllToAllTupleShape( - tensorflow::gtl::ArraySlice operand_shapes) { + absl::Span operand_shapes) { // An Alltoall HLO instruction receives N operands (with the same shape) and // returns a tuple that contains N array shapes. TF_RET_CHECK(!operand_shapes.empty()); @@ -1816,17 +1839,23 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "HLO all-to-all has operands with different shapes: the 0th " "operand shape %s, but the %dth operand has shape %s.", - ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i, - ShapeUtil::HumanString(*operand_shapes[i]).c_str()); + ShapeUtil::HumanString(*operand_shapes[0]), i, + ShapeUtil::HumanString(*operand_shapes[i])); } } return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); } +/* static */ StatusOr ShapeInference::InferCollectivePermuteShape( + const Shape& shape) { + TF_RET_CHECK(ShapeUtil::IsArray(shape)); + return shape; +} + /* static */ StatusOr ShapeInference::InferReduceShape( - tensorflow::gtl::ArraySlice arg_shapes, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span arg_shapes, + absl::Span dimensions_to_reduce, const ProgramShape& to_apply) { if (arg_shapes.empty()) { return InvalidArgument("Reduce must have at least 2 arguments, has 0"); @@ -1838,17 +1867,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } int64 num_reduced_args = arg_shapes.size() / 2; - tensorflow::gtl::ArraySlice 
reduced_args(arg_shapes, 0, - num_reduced_args); + auto reduced_args = arg_shapes.subspan(0, num_reduced_args); // Check that all of the reduced tensors have the same dimensions. The element // types may be different. for (int64 i = 1; i < num_reduced_args; ++i) { if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { return InvalidArgument( "All reduced tensors must have the sime dimension. Tensor 0 has " - "shape %s, Tensor %lld has shape %s", - ShapeUtil::HumanString(*reduced_args[0]).c_str(), i, - ShapeUtil::HumanString(*reduced_args[i]).c_str()); + "shape %s, Tensor %d has shape %s", + ShapeUtil::HumanString(*reduced_args[0]), i, + ShapeUtil::HumanString(*reduced_args[i])); } } @@ -1858,14 +1886,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { - return InvalidArgument( - "Reducing out-of-bounds dimension %lld in shape %s.", dimension, - ShapeUtil::HumanString(arg).c_str()); + return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.", + dimension, ShapeUtil::HumanString(arg)); } } - tensorflow::gtl::ArraySlice init_values( - arg_shapes, num_reduced_args, arg_shapes.size()); + auto init_values = arg_shapes.subspan(num_reduced_args, arg_shapes.size()); std::vector element_types; for (const Shape* arg : reduced_args) { element_types.push_back(arg->element_type()); @@ -1933,16 +1959,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Select function's first parameter shape currently must " "match the operand element shape, but got %s vs %s.", - ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), - ShapeUtil::HumanString(operand_element_shape).c_str()); + ShapeUtil::HumanString(select_shape.parameters(0)), + ShapeUtil::HumanString(operand_element_shape)); } if 
(!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, select_shape.parameters(1))) { return InvalidArgument( "Select function's second parameter shape currently must " "match the operand element shape, but got %s vs %s.", - ShapeUtil::HumanString(select_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(operand_element_shape).c_str()); + ShapeUtil::HumanString(select_shape.parameters(1)), + ShapeUtil::HumanString(operand_element_shape)); } // Check if the scatter function has a proper shape as a reduction. @@ -1960,43 +1986,40 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Source shape does not match the shape of window-reduced operand: " "source(%s), window-reduced operand(%s).", - ShapeUtil::HumanString(source_shape).c_str(), - ShapeUtil::HumanString(window_result_shape).c_str()); + ShapeUtil::HumanString(source_shape), + ShapeUtil::HumanString(window_result_shape)); } return operand_shape; } /* static */ StatusOr ShapeInference::InferSliceShape( - const Shape& arg, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits, - tensorflow::gtl::ArraySlice strides) { + const Shape& arg, absl::Span starts, + absl::Span limits, absl::Span strides) { auto error = [&](const string& message) { return InvalidArgument( "%s in slice operation; argument shape: %s; starts: {%s}; limits: " "{%s}; strides: {%s}.", - message.c_str(), ShapeUtil::HumanString(arg).c_str(), - Join(starts, ",").c_str(), Join(limits, ",").c_str(), - Join(strides, ",").c_str()); + message, ShapeUtil::HumanString(arg), StrJoin(starts, ","), + StrJoin(limits, ","), StrJoin(strides, ",")); }; TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice")); - VLOG(2) << tensorflow::strings::Printf( - "slicing shape %s starts={%s} limits={%s}", - ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(), - Join(limits, ", ").c_str()); + VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}", + 
ShapeUtil::HumanString(arg), StrJoin(starts, ", "), + StrJoin(limits, ", ")); if (starts.size() != limits.size()) { - return error(Printf("slice start and limit sizes differ: %zu vs %zu", - starts.size(), limits.size())); + return error(StrFormat("slice start and limit sizes differ: %u vs %u", + starts.size(), limits.size())); } if (starts.size() != strides.size()) { - return error(Printf("slice start and strides sizes differ: %zu vs %zu", - starts.size(), strides.size())); + return error(StrFormat("slice start and strides sizes differ: %u vs %u", + starts.size(), strides.size())); } if (starts.size() != ShapeUtil::Rank(arg)) { return InvalidArgument( - "Slice index count does not match argument rank: %zu vs %lld.", + "Slice index count does not match argument rank: %u vs %d.", starts.size(), ShapeUtil::Rank(arg)); } @@ -2006,27 +2029,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, int64 limit_index = limits[dimension]; int64 stride = strides[dimension]; if (start_index < 0) { - return InvalidArgument("Negative start index to slice: %lld.", - start_index); + return InvalidArgument("Negative start index to slice: %d.", start_index); } if (limit_index > arg.dimensions(dimension)) { return error( - Printf("limit index (%lld) must be less than or equal to dimension " - "size (%lld)", - limit_index, arg.dimensions(dimension))); - } - VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension, - start_index); - VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, - limit_index); + StrFormat("limit index (%d) must be less than or equal to dimension " + "size (%d)", + limit_index, arg.dimensions(dimension))); + } + VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index); + VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index); if (start_index > limit_index) { return error( - Printf("limit index (%lld) must be greater or equal to " - "start index (%lld) in slice with positive stride", - 
limit_index, start_index)); + StrFormat("limit index (%d) must be greater or equal to " + "start index (%d) in slice with positive stride", + limit_index, start_index)); } if (stride <= 0) { - return InvalidArgument("Stride (%lld) must be positive.", stride); + return InvalidArgument("Stride (%d) must be positive.", stride); } sizes.push_back((limit_index - start_index + stride - 1) / stride); } @@ -2036,20 +2056,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, - tensorflow::gtl::ArraySlice slice_sizes) { + absl::Span slice_sizes) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); TF_RETURN_IF_ERROR( ExpectArray(start_indices_shape, "start indices of dynamic slice")); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << StrFormat( "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", - ShapeUtil::HumanString(operand_shape).c_str(), - ShapeUtil::HumanString(start_indices_shape).c_str(), - Join(slice_sizes, ", ").c_str()); + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", ")); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "Dynamic slice start indices of rank %lld must be rank1.", + "Dynamic slice start indices of rank %d must be rank1.", ShapeUtil::Rank(start_indices_shape)); } @@ -2061,16 +2080,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "Dynamic slice start number of dimensions %lld (%s) must match rank " - "%lld of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::Rank(operand_shape), - ShapeUtil::HumanString(operand_shape).c_str()); + 
"Dynamic slice start number of dimensions %d (%s) must match rank " + "%d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); } if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( - "Dynamic slice index count does not match argument rank: %zu vs %lld.", + "Dynamic slice index count does not match argument rank: %u vs %d.", slice_sizes.size(), ShapeUtil::Rank(operand_shape)); } @@ -2078,16 +2096,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 input_dim_size = operand_shape.dimensions(dim); const int64 slice_dim_size = slice_sizes[dim]; if (slice_dim_size < 0) { - return InvalidArgument("Negative size index to dynamic slice: %lld.", + return InvalidArgument("Negative size index to dynamic slice: %d.", slice_dim_size); } if (slice_dim_size > input_dim_size) { return InvalidArgument( - "Slice dim size %lld greater than dynamic slice dimension: %lld.", + "Slice dim size %d greater than dynamic slice dimension: %d.", slice_dim_size, input_dim_size); } - VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim, - slice_dim_size); + VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size); } return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes); @@ -2103,16 +2120,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, "start indices of dynamic update slice")); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << StrFormat( "updating slice of shape %s at dynamic start_indices %s with update " "shape %s", - ShapeUtil::HumanString(operand_shape).c_str(), - ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::HumanString(update_shape).c_str()); + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + 
ShapeUtil::HumanString(update_shape)); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "Dynamic update slice start indices of rank %lld must be rank1.", + "Dynamic update slice start indices of rank %d must be rank1.", ShapeUtil::Rank(start_indices_shape)); } @@ -2124,17 +2141,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "Dynamic update slice start number of dimensions %lld (%s) must match " - "rank %lld of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::Rank(operand_shape), - ShapeUtil::HumanString(operand_shape).c_str()); + "Dynamic update slice start number of dimensions %d (%s) must match " + "rank %d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); } if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( "Dynamic update slice update rank does not match argument rank: " - "%lld vs %lld.", + "%d vs %d.", ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); } @@ -2143,8 +2159,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Dynamic update slice update element type does not match argument. 
" "operand.element_type: %s vs update.element_type: %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str(), - PrimitiveType_Name(update_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type()), + PrimitiveType_Name(update_shape.element_type())); } for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) { @@ -2152,23 +2168,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 update_dim_size = update_shape.dimensions(dim); if (update_dim_size < 0) { return InvalidArgument( - "Size index %lld to dynamic update slice must be >= 0.", + "Size index %d to dynamic update slice must be >= 0.", update_dim_size); } if (update_dim_size > input_dim_size) { return InvalidArgument( - "Update dim size %lld greater than dynamic slice dimension: %lld.", + "Update dim size %d greater than dynamic slice dimension: %d.", update_dim_size, input_dim_size); } - VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim, - update_dim_size); + VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size); } return operand_shape; } /*static */ StatusOr ShapeInference::InferReverseShape( - const Shape& operand_shape, tensorflow::gtl::ArraySlice dimensions) { + const Shape& operand_shape, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse")); if (!AllUnique(dimensions)) { return InvalidArgument("a dimension number is duplicated in reverse"); @@ -2176,8 +2191,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int64 dimension : dimensions) { if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { return InvalidArgument( - "One of the reverse dimensions (%lld) is out-of-bounds in shape %s.", - dimension, ShapeUtil::HumanString(operand_shape).c_str()); + "One of the reverse dimensions (%d) is out-of-bounds in shape %s.", + dimension, ShapeUtil::HumanString(operand_shape)); } } return operand_shape; @@ 
-2188,14 +2203,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::IsTuple(arg)) { return InvalidArgument( "Cannot infer shape: attempting to index into non-tuple: %s.", - ShapeUtil::HumanString(arg).c_str()); + ShapeUtil::HumanString(arg)); } if (index >= arg.tuple_shapes_size()) { return InvalidArgument( - "Cannot infer shape: attempt to index out of tuple bounds: %lld " + "Cannot infer shape: attempt to index out of tuple bounds: %d " ">= %d in shape %s.", - index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str()); + index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg)); } return arg.tuple_shapes(index); @@ -2215,17 +2230,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } auto shape_string = [&]() { - return tensorflow::strings::Printf( - "Condition: %s; body: %s; init: %s.", - ShapeUtil::HumanString(condition).c_str(), - ShapeUtil::HumanString(body).c_str(), - ShapeUtil::HumanString(init).c_str()); + return StrFormat( + "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition), + ShapeUtil::HumanString(body), ShapeUtil::HumanString(init)); }; // Check the shapes of computation parameters and return types. 
if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { return InvalidArgument("Condition must return a boolean; got %s.", - shape_string().c_str()); + shape_string()); } if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) || !ShapeUtil::Compatible(body.result(), body.parameters(0)) || @@ -2233,7 +2246,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The parameter of condition and body, the result of the body, and init " "must all have the same shape; got %s.", - shape_string().c_str()); + shape_string()); } return init; @@ -2245,7 +2258,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const ProgramShape& false_computation) { if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { return InvalidArgument("Predicate must be a boolean; got %s.", - ShapeUtil::HumanString(predicate).c_str()); + ShapeUtil::HumanString(predicate)); } if (true_computation.parameters_size() != 1) { @@ -2254,15 +2267,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) { auto true_shape_string = [&]() { - return tensorflow::strings::Printf( - "true_operand: %s; true_computation: %s", - ShapeUtil::HumanString(true_operand).c_str(), - ShapeUtil::HumanString(true_computation).c_str()); + return StrFormat("true_operand: %s; true_computation: %s", + ShapeUtil::HumanString(true_operand), + ShapeUtil::HumanString(true_computation)); }; return InvalidArgument( "true_operand must match the shape of the only parameter of " "true_computation: got %s.", - true_shape_string().c_str()); + true_shape_string()); } if (false_computation.parameters_size() != 1) { @@ -2271,38 +2283,37 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) { auto false_shape_string = [&]() { - return tensorflow::strings::Printf( 
- "false_operand: %s; false_computation: %s", - ShapeUtil::HumanString(false_operand).c_str(), - ShapeUtil::HumanString(false_computation).c_str()); + return StrFormat("false_operand: %s; false_computation: %s", + ShapeUtil::HumanString(false_operand), + ShapeUtil::HumanString(false_computation)); }; return InvalidArgument( "false_operand must match the shape of the only parameter of " "false_computation: got %s.", - false_shape_string().c_str()); + false_shape_string()); } if (!ShapeUtil::Compatible(true_computation.result(), false_computation.result())) { auto shape_string = [&]() { - return tensorflow::strings::Printf( + return StrFormat( "true_computation result: %s; false_computation result: %s.", - ShapeUtil::HumanString(true_computation.result()).c_str(), - ShapeUtil::HumanString(false_computation.result()).c_str()); + ShapeUtil::HumanString(true_computation.result()), + ShapeUtil::HumanString(false_computation.result())); }; return InvalidArgument( "the result of true_computation and false_computation must have the " "same shape: got %s.", - shape_string().c_str()); + shape_string()); } return true_computation.result(); } /* static */ StatusOr ShapeInference::InferBroadcastShape( - const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { + const Shape& operand, absl::Span broadcast_sizes) { TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); for (int64 size : broadcast_sizes) { if (size < 0) { - return InvalidArgument("Broadcast with negative dimension size %lld.", + return InvalidArgument("Broadcast with negative dimension size %d.", size); } } @@ -2316,8 +2327,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferReshapeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { + const Shape& operand, absl::Span dimensions, + absl::Span new_sizes) { TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); 
Shape inferred_shape = @@ -2327,11 +2338,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( - "Reshape operation has mismatched element counts: from=%lld (%s) " - "to=%lld (%s).", - ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(), + "Reshape operation has mismatched element counts: from=%d (%s) " + "to=%d (%s).", + ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand), ShapeUtil::ElementsIn(inferred_shape), - ShapeUtil::HumanString(inferred_shape).c_str()); + ShapeUtil::HumanString(inferred_shape)); } std::vector indices(ShapeUtil::Rank(operand)); @@ -2342,14 +2353,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Reshape dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).", - Join(dimensions, ",").c_str(), ShapeUtil::HumanString(operand).c_str()); + StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } return inferred_shape; } /* static */ StatusOr ShapeInference::InferTransposeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions) { + const Shape& operand, absl::Span dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); std::vector indices(ShapeUtil::Rank(operand)); @@ -2377,9 +2388,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("Clamp with different operand types: %s, %s, %s.", - ShapeUtil::HumanString(min).c_str(), - ShapeUtil::HumanString(operand).c_str(), - ShapeUtil::HumanString(max).c_str()); + ShapeUtil::HumanString(min), + ShapeUtil::HumanString(operand), + ShapeUtil::HumanString(max)); } if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) || 
ShapeUtil::IsScalar(min)) && @@ -2396,9 +2407,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::ChangeElementType(min, operand.element_type()); } } - return Unimplemented( - "%s, %s %s is not implemented.", min.ShortDebugString().c_str(), - max.ShortDebugString().c_str(), operand.ShortDebugString().c_str()); + return Unimplemented("%s, %s %s is not implemented.", + min.ShortDebugString(), max.ShortDebugString(), + operand.ShortDebugString()); } // TODO(b/36794510): Make broadcast semantics more consistent, by supporting @@ -2409,13 +2420,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) { return InvalidArgument( "Operands to select must be the same shape; got %s and %s.", - ShapeUtil::HumanString(on_true).c_str(), - ShapeUtil::HumanString(on_false).c_str()); + ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false)); } if (pred.element_type() != PRED) { return InvalidArgument( "Select's pred operand must have PRED element type; got %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || ShapeUtil::IsScalar(pred)) { @@ -2428,7 +2438,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Select operation with non-scalar predicate with dimensionality " " different from the other operands: %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } } @@ -2439,38 +2449,36 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::Compatible(on_true, on_false)) { return InvalidArgument( "Operands to tuple-select must be the same shape; got %s and %s.", - ShapeUtil::HumanString(on_true).c_str(), - ShapeUtil::HumanString(on_false).c_str()); + ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false)); } if 
(pred.element_type() != PRED) { return InvalidArgument( "TupleSelect's pred operand must have PRED element type; got %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } if (!ShapeUtil::IsScalar(pred)) { return InvalidArgument( "TupleSelect operation with non-scalar predicate: %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } return on_true; } /* static */ StatusOr ShapeInference::InferCallShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply) { + absl::Span arg_shapes, const ProgramShape& to_apply) { // The applied function's arity equals the number of arguments. if (arg_shapes.size() != to_apply.parameters_size()) { string computation_signature = ShapeUtil::HumanString(to_apply); string argument_shapes = - Join(arg_shapes, ", ", [](string* out, const Shape* shape) { - tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape)); + StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) { + absl::StrAppend(out, ShapeUtil::HumanString(*shape)); }); return InvalidArgument( "Call applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu; computation signature: %s; argument " + "arity: %d, arguments: %u; computation signature: %s; argument " "shapes: [%s].", - to_apply.parameters_size(), arg_shapes.size(), - computation_signature.c_str(), argument_shapes.c_str()); + to_apply.parameters_size(), arg_shapes.size(), computation_signature, + argument_shapes); } // All arguments must be compatible with the program shape. 
@@ -2481,8 +2489,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Call parameter must match argument; got parameter %d shape: %s, " "argument shape: %s.", - i, ShapeUtil::HumanString(param_shape).c_str(), - ShapeUtil::HumanString(arg_shape).c_str()); + i, ShapeUtil::HumanString(param_shape), + ShapeUtil::HumanString(arg_shape)); } } @@ -2490,202 +2498,198 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } static Status ValidateGatherDimensionNumbers( - const Shape& input_shape, - tensorflow::gtl::ArraySlice gather_indices_shape, + const Shape& input_shape, absl::Span start_indices_shape, const GatherDimensionNumbers& dim_numbers) { - if (!c_is_sorted(dim_numbers.output_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.offset_dims())) { return InvalidArgument( "Output window dimensions in gather op must be ascending; got: %s.", - Join(dim_numbers.output_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.offset_dims(), ", ")); } - if (c_adjacent_find(dim_numbers.output_window_dims()) != - dim_numbers.output_window_dims().end()) { + if (absl::c_adjacent_find(dim_numbers.offset_dims()) != + dim_numbers.offset_dims().end()) { return InvalidArgument( "Output window dimensions in gather op must not repeat; got: %s.", - Join(dim_numbers.output_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.offset_dims(), ", ")); } - const int64 output_window_dim_count = dim_numbers.output_window_dims_size(); + const int64 output_offset_dim_count = dim_numbers.offset_dims_size(); const int64 output_shape_rank = - output_window_dim_count + gather_indices_shape.size() - 1; + output_offset_dim_count + start_indices_shape.size() - 1; - for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) { - int64 window_index = dim_numbers.output_window_dims(i); - if (window_index < 0 || window_index >= output_shape_rank) { + for (int i = 0; i < dim_numbers.offset_dims_size(); ++i) { + int64 
offset_dim = dim_numbers.offset_dims(i); + if (offset_dim < 0 || offset_dim >= output_shape_rank) { return InvalidArgument( - "Window index %d in gather op is out of bounds; got %lld, but should " - "have been in [0,%lld).", - i, window_index, output_shape_rank); + "Offset dimension %d in gather op is out of bounds; got %d, but " + "should " + "have been in [0,%d).", + i, offset_dim, output_shape_rank); } } - if (dim_numbers.gather_dims_to_operand_dims_size() != - gather_indices_shape[dim_numbers.index_vector_dim()]) { + if (dim_numbers.start_index_map_size() != + start_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( - "Gather op has %d elements in gather_dims_to_operand_dims and the " - "bound of dimension index_vector_dim=%lld of gather_indices is " - "%lld. These two numbers must be equal.", - dim_numbers.gather_dims_to_operand_dims_size(), - dim_numbers.index_vector_dim(), - gather_indices_shape[dim_numbers.index_vector_dim()]); + "Gather op has %d elements in start_index_map and the " + "bound of dimension index_vector_dim=%d of start_indices is " + "%d. 
These two numbers must be equal.", + dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(), + start_indices_shape[dim_numbers.index_vector_dim()]); } - for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) { - int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i); - if (gather_dim_to_input_dim < 0 || - gather_dim_to_input_dim >= input_shape.dimensions_size()) { + for (int i = 0; i < dim_numbers.start_index_map_size(); i++) { + int64 operand_dim_for_start_index_i = dim_numbers.start_index_map(i); + if (operand_dim_for_start_index_i < 0 || + operand_dim_for_start_index_i >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), " - "got: %d->%lld.", - input_shape.dimensions_size(), i, gather_dim_to_input_dim); + "Invalid start_index_map; domain is [0, %d), got: %d->%d.", + input_shape.dimensions_size(), i, operand_dim_for_start_index_i); } } - std::vector sorted_gather_dims_to_operand_dims( - dim_numbers.gather_dims_to_operand_dims().begin(), - dim_numbers.gather_dims_to_operand_dims().end()); + std::vector sorted_start_index_map( + dim_numbers.start_index_map().begin(), + dim_numbers.start_index_map().end()); - c_sort(sorted_gather_dims_to_operand_dims); + absl::c_sort(sorted_start_index_map); - if (c_adjacent_find(sorted_gather_dims_to_operand_dims) != - sorted_gather_dims_to_operand_dims.end()) { + if (absl::c_adjacent_find(sorted_start_index_map) != + sorted_start_index_map.end()) { return InvalidArgument( - "Repeated dimensions are not allowed in gather_dims_to_operand_dims; " + "Repeated dimensions are not allowed in start_index_map; " "got: %s.", - Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str()); + StrJoin(dim_numbers.start_index_map(), ", ")); } - for (int64 elided_dim : dim_numbers.elided_window_dims()) { - if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) { + for (int64 collapsed_dim : 
dim_numbers.collapsed_slice_dims()) { + if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid elided_window_dims set in gather op; valid range is [0, " - "%d), got: %lld.", - input_shape.dimensions_size(), elided_dim); + "Invalid collapsed_slice_dims set in gather op; valid range is [0, " + "%d), got: %d.", + input_shape.dimensions_size(), collapsed_dim); } } - if (!c_is_sorted(dim_numbers.elided_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) { return InvalidArgument( - "elided_window_dims in gather op must be sorted; got: %s", - Join(dim_numbers.elided_window_dims(), ", ").c_str()); + "collapsed_slice_dims in gather op must be sorted; got: %s", + StrJoin(dim_numbers.collapsed_slice_dims(), ", ")); } - if (c_adjacent_find(dim_numbers.elided_window_dims()) != - dim_numbers.elided_window_dims().end()) { + if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) != + dim_numbers.collapsed_slice_dims().end()) { return InvalidArgument( - "Repeated dimensions not allowed in elided_window_dims in gather op; " + "Repeated dimensions not allowed in collapsed_slice_dims in gather op; " "got: %s.", - Join(dim_numbers.elided_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.collapsed_slice_dims(), ", ")); } return Status::OK(); } /*static*/ StatusOr ShapeInference::InferGatherShape( - const Shape& input_shape, const Shape& gather_indices_shape, + const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds) { + absl::Span slice_sizes) { TF_RETURN_IF_ERROR( ExpectArray(input_shape, "input tensor operand gather op")); TF_RETURN_IF_ERROR( - ExpectArray(gather_indices_shape, "gather indices operand of gather op")); + ExpectArray(start_indices_shape, "gather indices operand of gather op")); - if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { + if 
(!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( "Gather indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(gather_indices_shape).c_str()); + ShapeUtil::HumanString(start_indices_shape)); } // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if // index_vector_dim is rank(P). The bounds of this expanded shape is - // stored in expanded_gather_indices_shape. + // stored in expanded_start_indices_shape. - if (gather_indices_shape.dimensions_size() < + if (start_indices_shape.dimensions_size() < gather_dim_numbers.index_vector_dim() || gather_dim_numbers.index_vector_dim() < 0) { return InvalidArgument( - "Gather index leaf dimension must be within [0, rank(gather_indices) + " - "1). rank(gather_indices) is %d and gather index leaf dimension is " - "%lld.", - gather_indices_shape.dimensions_size(), + "Gather index leaf dimension must be within [0, rank(start_indices) + " + "1). rank(start_indices) is %d and gather index leaf dimension is " + "%d.", + start_indices_shape.dimensions_size(), gather_dim_numbers.index_vector_dim()); } - std::vector expanded_gather_indices_shape; - expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size()); - c_copy(gather_indices_shape.dimensions(), - std::back_inserter(expanded_gather_indices_shape)); - if (expanded_gather_indices_shape.size() == + std::vector expanded_start_indices_shape; + expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size()); + absl::c_copy(start_indices_shape.dimensions(), + std::back_inserter(expanded_start_indices_shape)); + if (expanded_start_indices_shape.size() == gather_dim_numbers.index_vector_dim()) { - expanded_gather_indices_shape.push_back(1); + expanded_start_indices_shape.push_back(1); } TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( - input_shape, expanded_gather_indices_shape, gather_dim_numbers)); + input_shape, expanded_start_indices_shape, gather_dim_numbers)); - if 
(window_bounds.size() != input_shape.dimensions_size()) { + if (slice_sizes.size() != input_shape.dimensions_size()) { return InvalidArgument( - "Gather op must have one window bound for every input dimension; got: " - "len(window_bounds)=%lu, input_shape.rank=%d.", - window_bounds.size(), input_shape.dimensions_size()); + "Gather op must have one slice size for every input dimension; got: " + "len(slice_sizes)=%lu, input_shape.rank=%d.", + slice_sizes.size(), input_shape.dimensions_size()); } - if (window_bounds.size() != - gather_dim_numbers.output_window_dims_size() + - gather_dim_numbers.elided_window_dims_size()) { + if (slice_sizes.size() != + gather_dim_numbers.offset_dims_size() + + gather_dim_numbers.collapsed_slice_dims_size()) { return InvalidArgument( - "All components of the window index in a gather op must either be a " - "output window index or explicitly elided; got len(window_bounds)=%lu, " - "output_window_bounds=%s, elided_window_bounds=%s.", - window_bounds.size(), - Join(gather_dim_numbers.output_window_dims(), ",").c_str(), - Join(gather_dim_numbers.elided_window_dims(), ",").c_str()); + "All components of the offset index in a gather op must either be a " + "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, " + "output_slice_sizes=%s, collapsed_slice_dims=%s.", + slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","), + StrJoin(gather_dim_numbers.collapsed_slice_dims(), ",")); } - for (int i = 0; i < window_bounds.size(); i++) { - int64 window_bound = window_bounds[i]; - int64 corresponding_input_bound = input_shape.dimensions(i); - if (window_bound < 0 || window_bound > corresponding_input_bound) { + for (int i = 0; i < slice_sizes.size(); i++) { + int64 slice_size = slice_sizes[i]; + int64 corresponding_input_size = input_shape.dimensions(i); + if (slice_size < 0 || slice_size > corresponding_input_size) { return InvalidArgument( - "Window bound at index %d in gather op is out of range, must be " - "within " - 
"[0, %lld), got %lld.", - i, corresponding_input_bound + 1, window_bound); + "Slice size at index %d in gather op is out of range, must be " + "within [0, %d), got %d.", + i, corresponding_input_size + 1, slice_size); } } - for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) { - if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) { + for (int i = 0; i < gather_dim_numbers.collapsed_slice_dims_size(); i++) { + if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) { return InvalidArgument( - "Gather op can only elide window indices with bound 1, but bound is " - "%lld for index %lld at position %d.", - window_bounds[gather_dim_numbers.elided_window_dims(i)], - gather_dim_numbers.elided_window_dims(i), i); + "Gather op can only collapse slice dims with bound 1, but bound is " + "%d for index %d at position %d.", + slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)], + gather_dim_numbers.collapsed_slice_dims(i), i); } } - int64 result_rank = gather_dim_numbers.output_window_dims_size() + - (expanded_gather_indices_shape.size() - 1); - int64 window_dims_seen = 0; + int64 result_rank = gather_dim_numbers.offset_dims_size() + + (expanded_start_indices_shape.size() - 1); + int64 offset_dims_seen = 0; int64 gather_dims_seen = 0; std::vector output_dim_bounds; output_dim_bounds.reserve(result_rank); for (int64 i = 0; i < result_rank; i++) { int64 current_bound; bool is_window_index = - c_binary_search(gather_dim_numbers.output_window_dims(), i); + absl::c_binary_search(gather_dim_numbers.offset_dims(), i); if (is_window_index) { - while (c_binary_search(gather_dim_numbers.elided_window_dims(), - window_dims_seen)) { - window_dims_seen++; + while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(), + offset_dims_seen)) { + offset_dims_seen++; } - current_bound = window_bounds[window_dims_seen++]; + current_bound = slice_sizes[offset_dims_seen++]; } else { if (gather_dims_seen == 
gather_dim_numbers.index_vector_dim()) { gather_dims_seen++; } - current_bound = expanded_gather_indices_shape[gather_dims_seen++]; + current_bound = expanded_start_indices_shape[gather_dims_seen++]; } output_dim_bounds.push_back(current_bound); @@ -2697,48 +2701,47 @@ static Status ValidateGatherDimensionNumbers( namespace { Status ValidateScatterDimensionNumbers( - const Shape& operand_shape, - tensorflow::gtl::ArraySlice scatter_indices_shape, + const Shape& operand_shape, absl::Span scatter_indices_shape, const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { // Validate update_window_dims in ScatterDimensionNumbers. - if (!c_is_sorted(dim_numbers.update_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.update_window_dims())) { return InvalidArgument( "update_window_dims in scatter op must be sorted; got: %s.", - Join(dim_numbers.update_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.update_window_dims(), ", ")); } - if (c_adjacent_find(dim_numbers.update_window_dims()) != + if (absl::c_adjacent_find(dim_numbers.update_window_dims()) != dim_numbers.update_window_dims().end()) { return InvalidArgument( "update_window_dims in scatter op must not repeat; got: %s.", - Join(dim_numbers.update_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.update_window_dims(), ", ")); } const int64 updates_rank = ShapeUtil::Rank(updates_shape); for (int64 window_dim : dim_numbers.update_window_dims()) { if (window_dim < 0 || window_dim >= updates_rank) { return InvalidArgument( "Invalid update_window_dims set in scatter op; valid range is [0, " - "%lld). got: %lld.", + "%d). got: %d.", updates_rank, window_dim); } } // Validate inserted_window_dims in ScatterDimensionNumbers. 
- if (!c_is_sorted(dim_numbers.inserted_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) { return InvalidArgument( "inserted_window_dims in scatter op must be sorted; got: %s.", - Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.inserted_window_dims(), ", ")); } - if (c_adjacent_find(dim_numbers.inserted_window_dims()) != + if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) != dim_numbers.inserted_window_dims().end()) { return InvalidArgument( "inserted_window_dims in scatter op must not repeat; got: %s.", - Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.inserted_window_dims(), ", ")); } for (int64 inserted_dim : dim_numbers.inserted_window_dims()) { if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) { return InvalidArgument( "Invalid inserted_window_dims set in scatter op; valid range is [0, " - "%d), got: %lld.", + "%d), got: %d.", operand_shape.dimensions_size(), inserted_dim); } } @@ -2748,7 +2751,7 @@ Status ValidateScatterDimensionNumbers( scatter_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( "Scatter op has %d elements in scatter_dims_to_operand_dims and the " - "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. " + "bound of dimension index_vector_dim=%d of scatter_indices is %d. 
" "These two numbers must be equal.", dim_numbers.scatter_dims_to_operand_dims_size(), dim_numbers.index_vector_dim(), @@ -2761,20 +2764,20 @@ Status ValidateScatterDimensionNumbers( scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) { return InvalidArgument( "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), " - "got: %d->%lld.", + "got: %d->%d.", operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim); } } std::vector sorted_scatter_dims_to_operand_dims( dim_numbers.scatter_dims_to_operand_dims().begin(), dim_numbers.scatter_dims_to_operand_dims().end()); - c_sort(sorted_scatter_dims_to_operand_dims); - if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) != + absl::c_sort(sorted_scatter_dims_to_operand_dims); + if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) != sorted_scatter_dims_to_operand_dims.end()) { return InvalidArgument( "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " "got: %s.", - Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); + StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", ")); } return Status::OK(); @@ -2795,7 +2798,7 @@ Status ValidateScatterDimensionNumbers( if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { return InvalidArgument( "Scatter indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(scatter_indices_shape).c_str()); + ShapeUtil::HumanString(scatter_indices_shape)); } if (scatter_indices_shape.dimensions_size() < @@ -2804,7 +2807,7 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Scatter index leaf dimension must be within [0, rank(scatter_indices)" " + 1). 
rank(scatter_indices) is %d and scatter index leaf dimension " - "is %lld.", + "is %d.", scatter_indices_shape.dimensions_size(), scatter_dim_numbers.index_vector_dim()); } @@ -2826,7 +2829,7 @@ Status ValidateScatterDimensionNumbers( int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + scatter_dim_numbers.update_window_dims_size(); if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { - return InvalidArgument("Updates tensor must be of rank %lld; got %lld.", + return InvalidArgument("Updates tensor must be of rank %d; got %d.", expected_updates_rank, ShapeUtil::Rank(updates_shape)); } @@ -2836,32 +2839,32 @@ Status ValidateScatterDimensionNumbers( scatter_dim_numbers)); int64 inserted_dims_seen = 0; - std::vector max_update_window_bounds; + std::vector max_update_slice_sizes; for (int i = 0; i < operand_shape.dimensions_size(); ++i) { if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() && scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) { ++inserted_dims_seen; } else { - max_update_window_bounds.push_back(operand_shape.dimensions(i)); + max_update_slice_sizes.push_back(operand_shape.dimensions(i)); } } for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) { auto update_window_dim = scatter_dim_numbers.update_window_dims(i); if (updates_shape.dimensions(update_window_dim) > - max_update_window_bounds[i]) { + max_update_slice_sizes[i]) { return InvalidArgument( "Bounds of the window dimensions of updates must not exceed the " "bounds of the corresponding dimensions of operand. 
For dimension " - "%lld, updates bound is %lld, operand bound is %lld.", + "%d, updates bound is %d, operand bound is %d.", update_window_dim, updates_shape.dimensions(update_window_dim), - max_update_window_bounds[i]); + max_update_slice_sizes[i]); } } int64 scatter_dims_seen = 0; for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { bool is_update_window_dim = - c_binary_search(scatter_dim_numbers.update_window_dims(), i); + absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i); if (is_update_window_dim) { continue; } @@ -2873,8 +2876,8 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Bounds of the scatter dimensions of updates must be same as the " "bounds of the corresponding dimensions of scatter indices. For " - "scatter dimension %lld, updates bound is %lld, scatter_indices " - "bound is %lld.", + "scatter dimension %d, updates bound is %d, scatter_indices " + "bound is %d.", i, updates_shape.dimensions(i), expanded_scatter_indices_shape[scatter_dims_seen]); } diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index c185b0a1bd79e23e0d76daad50fb4a9708a743dd..a28345acefb8fca1c8b6444f431f932c23c57ce4 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -21,12 +21,12 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -55,7 +55,7 @@ class ShapeInference { // given input shapes. 
static StatusOr InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); static StatusOr InferBinaryOpShape(HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs); @@ -73,18 +73,15 @@ class ShapeInference { // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. static StatusOr InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operand_shapes); + HloOpcode opcode, absl::Span operand_shapes); static StatusOr InferVariadicOpShape( - HloOpcode opcode, - tensorflow::gtl::ArraySlice operands); + HloOpcode opcode, absl::Span operands); // Infers the shape produced by applying the given mapping computation shape // to the given operand shapes. static StatusOr InferMapShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply, - tensorflow::gtl::ArraySlice dimensions); + absl::Span arg_shapes, const ProgramShape& to_apply, + absl::Span dimensions); // Infers the shape produced by InferBatchNormTraining with the given // operands. @@ -112,17 +109,17 @@ class ShapeInference { // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1); // Infers the shape produced by the given FFT type on the given operand. - static StatusOr InferFftShape( - const Shape& in, FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); + static StatusOr InferFftShape(const Shape& in, FftType fft_type, + absl::Span fft_length); // Infers the shape produced by a cross replica sum with the given operand // shapes. 
static StatusOr InferCrossReplicaSumShape( - tensorflow::gtl::ArraySlice operand_shapes); + absl::Span operand_shapes); // Infers final shape of an Alltoall operation that is created by the xla // builder. @@ -133,7 +130,10 @@ class ShapeInference { // Infers the shape of an HLO all-to-all instruction. static StatusOr InferAllToAllTupleShape( - tensorflow::gtl::ArraySlice operand_shapes); + absl::Span operand_shapes); + + // Infers the shape of a collective permute operation. + static StatusOr InferCollectivePermuteShape(const Shape& shape); // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. @@ -142,8 +142,8 @@ class ShapeInference { // index as the leading parameter, and the program shape should match // accordingly (or an error will result). static StatusOr InferReduceShape( - tensorflow::gtl::ArraySlice arg_shapes, - tensorflow::gtl::ArraySlice dimensions_to_reduce, + absl::Span arg_shapes, + absl::Span dimensions_to_reduce, const ProgramShape& to_apply); // Infers the shape produced by applying the given computation to the operand @@ -161,24 +161,23 @@ class ShapeInference { // Infers the shape produced by a reverse operation that reverses the order // of the elements in the given dimensions. - static StatusOr InferReverseShape( - const Shape& operand_shape, - tensorflow::gtl::ArraySlice dimensions); + static StatusOr InferReverseShape(const Shape& operand_shape, + absl::Span dimensions); // Infers the shape produced by a slice operation spanning from the starts to // the limits in the original shape's dimensions. // // e.g. 
slice f32[32x32] 0:16 0:16 -> f32[16x16] - static StatusOr InferSliceShape( - const Shape& arg, tensorflow::gtl::ArraySlice starts, - tensorflow::gtl::ArraySlice limits, - tensorflow::gtl::ArraySlice strides); + static StatusOr InferSliceShape(const Shape& arg, + absl::Span starts, + absl::Span limits, + absl::Span strides); // Infers the shape produced by a dynamic slice operation of size specified // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'. static StatusOr InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, - tensorflow::gtl::ArraySlice slice_sizes); + absl::Span slice_sizes); // Infers the shape produced by a dynamic update slice operation based // on the shape of operand and update. @@ -209,30 +208,30 @@ class ShapeInference { // Infers the shape produced by a broadcast operation. static StatusOr InferBroadcastShape( - const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes); + const Shape& operand, absl::Span broadcast_sizes); // Infers the shape produced by a reshape operation from the element type of // its operand and the new dimension sizes specified. - static StatusOr InferReshapeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + static StatusOr InferReshapeShape(const Shape& operand, + absl::Span dimensions, + absl::Span new_sizes); // Infers the shape produced by a transpose operation from the element type of // its operand and its dimensions field. static StatusOr InferTransposeShape( - const Shape& operand, tensorflow::gtl::ArraySlice dimensions); + const Shape& operand, absl::Span dimensions); // Helper that infers the shape produced by performing a concatenate operation // with the given operand shapes. static StatusOr InferConcatOpShape( - tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); + absl::Span arg_shapes, int64 dimension); // Infers the shape produced by a kAfterAll. 
Trivially this shape is always a // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes // and checking operand shapes. This method verifies that the operand shapes // are all TOKENs. static StatusOr InferAfterAllShape( - tensorflow::gtl::ArraySlice arg_shapes); + absl::Span arg_shapes); // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that @@ -262,8 +261,7 @@ class ShapeInference { // Helper that validates the given arg_shapes are compatible with the shape of // the to_apply parameters, and returns the to_apply result shape. static StatusOr InferCallShape( - tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply); + absl::Span arg_shapes, const ProgramShape& to_apply); // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. @@ -275,9 +273,9 @@ class ShapeInference { // with the given input shape, gather indices shape and gather dimension // numbers. static StatusOr InferGatherShape( - const Shape& input_shape, const Shape& gather_indices_shape, + const Shape& input_shape, const Shape& start_indices_shape, const GatherDimensionNumbers& gather_dim_numbers, - tensorflow::gtl::ArraySlice window_bounds); + absl::Span slice_sizes); // Helper that validates the given input shape, scatter indices shape, updates // shape, and scatter dimension numbers that constitute a scatter operation, @@ -295,7 +293,7 @@ class ShapeInference { // even in the presence of broadcasting of one of the operands over the other. static StatusOr InferElementwiseBinaryOpShape( HloOpcode operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); // Helper for inferring the shape of Clamp ops. static StatusOr InferClampShape(const Shape& min, const Shape& operand, @@ -323,7 +321,7 @@ class ShapeInference { // smaller_shape is broadcast to. 
static StatusOr InferInDimBroadcastShape( const Shape& smaller_shape, const Shape& larger_shape, - tensorflow::gtl::ArraySlice broadcast_dimensions); + absl::Span broadcast_dimensions); TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); }; diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index a73fa181cdd13dc7fabcdc367ae117e19bdc3e5f..cc92e58ef867ee716714fff4fdab07b9cb836d00 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -17,18 +17,17 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { -using ::tensorflow::gtl::ArraySlice; using ::testing::ContainsRegex; using ::testing::HasSubstr; @@ -58,9 +57,9 @@ class ReduceShapeInferenceTest : public ShapeInferenceTest { // Helper that runs reduce shape inference with the input 'arg' and given // dimensions to reduce, and checks the inferred shape is as expected. The // element type here is hard-coded to F32. 
- void ExpectInferredReduceShape( - const Shape& expected_inferred_shape, const Shape& arg, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { + void ExpectInferredReduceShape(const Shape& expected_inferred_shape, + const Shape& arg, + absl::Span dimensions_to_reduce) { ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); auto inferred_status = ShapeInference::InferReduceShape( {&arg, &f32_}, dimensions_to_reduce, to_apply); @@ -252,7 +251,7 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) { TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, - const tensorflow::gtl::ArraySlice& bcast) { + const absl::Span& bcast) { return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, bcast); }; @@ -1654,11 +1653,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) { ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1})); + /*slice_sizes=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) << ShapeUtil::HumanString(gather_shape); @@ -1669,11 +1668,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) { ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{1}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{1}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/1), - /*window_bounds=*/{1, 48})); + /*slice_sizes=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) << ShapeUtil::HumanString(gather_shape); @@ -1684,11 +1683,11 @@ 
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) { ShapeInference::InferGatherShape( matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{4}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/4), - /*window_bounds=*/{1, 48})); + /*slice_sizes=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) << ShapeUtil::HumanString(gather_shape); @@ -1700,11 +1699,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}))) @@ -1717,11 +1716,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/2), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1735,11 +1734,11 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { 
ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/0), - /*window_bounds=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, @@ -1749,16 +1748,15 @@ TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) { // This is equivalent to a dynamic slice. - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape( - f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, - HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0, 1, 2, 3, 4}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/0), - /*window_bounds=*/{30, 29, 28, 27, 26})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{0, 1, 2, 3, 4}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}))) @@ -1772,11 +1770,11 @@ TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) { ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_scalar_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0, 1, 2, 3}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, + /*offset_dims=*/{0, 1, 2, 3}, + /*collapsed_slice_dims=*/{0}, + /*start_index_map=*/{0}, /*index_vector_dim=*/0), - /*window_bounds=*/{1, 30, 
29, 28, 27})); + /*slice_sizes=*/{1, 30, 29, 28, 27})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {30, 29, 28, 27}))) @@ -1787,11 +1785,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Expected array argument for input")) @@ -1802,11 +1800,11 @@ TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/0), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Expected array argument for gather indices")) @@ -1817,11 +1815,11 @@ TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, + /*offset_dims=*/{0}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, /*index_vector_dim=*/0), - /*window_bounds=*/{64, 1}); + /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather indices parameter must be an integral 
tensor")) @@ -1833,11 +1831,11 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 8, 7}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 8, 7}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), @@ -1850,11 +1848,11 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 7}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 7}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), @@ -1867,14 +1865,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 99, 100, 101}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 99, 100, 101}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window index 2 in gather op is out of bounds")) + HasSubstr("Offset 
dimension 2 in gather op is out of bounds")) << statusor.status(); } @@ -1883,14 +1881,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 9}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 9}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window index 4 in gather op is out of bounds")) + HasSubstr("Offset dimension 4 in gather op is out of bounds")) << statusor.status(); } @@ -1899,16 +1897,16 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{4}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{4}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr("All components of the window index in a gather op must either " - "be a output window index or explicitly elided")) + HasSubstr("All components of the offset index in a gather op must either " + "be a offset dimension or explicitly collapsed")) << statusor.status(); } @@ -1917,14 +1915,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - 
/*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{0, 1, 2, 3, 19}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{0, 1, 2, 3, 19}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Invalid elided_window_dims set in gather op; valid " + HasSubstr("Invalid collapsed_slice_dims set in gather op; valid " "range is [0, 5), got: 19")) << statusor.status(); } @@ -1934,16 +1932,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{0, 1, 2, 3, 3}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{0, 1, 2, 3, 3}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr( - "Repeated dimensions not allowed in elided_window_dims in gather op")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Repeated dimensions not allowed in " + "collapsed_slice_dims in gather op")) << statusor.status(); } @@ -1952,17 +1949,16 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3}, 
/*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and " - "the bound of dimension index_vector_dim=4 of " - "gather_indices is 5. These two numbers must be equal.")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather op has 4 elements in start_index_map and " + "the bound of dimension index_vector_dim=4 of " + "start_indices is 5. These two numbers must be equal.")) << statusor.status(); } @@ -1971,16 +1967,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 7}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT( - statusor.status().error_message(), - HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is " - "[0, 5), got: 4->7")) + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7")) << statusor.status(); } @@ -1989,16 +1983,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 3}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 
27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "Repeated dimensions are not allowed in gather_dims_to_operand_dims")) + HasSubstr("Repeated dimensions are not allowed in start_index_map")) << statusor.status(); } @@ -2007,14 +2000,14 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{2, 1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{2, 1}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{1, 1, 28, 27, 26}); + /*slice_sizes=*/{1, 1, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("elided_window_dims in gather op must be sorted")) + HasSubstr("collapsed_slice_dims in gather op must be sorted")) << statusor.status(); } @@ -2023,15 +2016,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7}, - /*elided_window_dims=*/{2}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7}, + /*collapsed_slice_dims=*/{2}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 1, 300, 26}); + /*slice_sizes=*/{30, 29, 1, 300, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Window bound at index 3 in gather op is out of range, " - "must be within [0, 48), got 300")) + HasSubstr("Slice size at index 3 in gather op is out of range, " + "must be within [0, 48), got 300.")) << statusor.status(); } @@ -2040,16 +2033,15 @@ 
TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 26}); + /*slice_sizes=*/{30, 29, 28, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), - HasSubstr( - "Gather op must have one window bound for every input dimension")) + HasSubstr("Gather op must have one slice size for every input dimension")) << statusor.status(); } @@ -2058,15 +2050,15 @@ TEST_F(ScatterGatherShapeInferenceTest, StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*window_bounds=*/{30, 29, 28, 26, 20}); + /*slice_sizes=*/{30, 29, 28, 26, 20}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Gather op can only elide window indices with bound 1, " - "but bound is 29 for index 1 at position 0")) + HasSubstr("Gather op can only collapse slice dims with bound 1, " + "but bound is 29 for index 1 at position 0.")) << statusor.status(); } @@ -2074,16 +2066,16 @@ TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4, 5, 6, 7, 8}, - /*elided_window_dims=*/{}, - 
/*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*offset_dims=*/{4, 5, 6, 7, 8}, + /*collapsed_slice_dims=*/{}, + /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/32), - /*window_bounds=*/{30, 29, 28, 27, 26}); + /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather index leaf dimension must be within [0, " - "rank(gather_indices) + 1)")) + "rank(start_indices) + 1)")) << statusor.status(); } diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 7d7dcac10b65933d1c81b8aca77465932694bfdb..921a984589bb4fb64058a2a56adfe84fe14af69b 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -18,20 +18,19 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::strings::Appendf; - ShapedBuffer::ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, const se::Platform* platform, int device_ordinal) @@ -76,7 +75,7 @@ void ShapedBuffer::clear() { } string ShapedBuffer::ToString() const { - string s = tensorflow::strings::StrCat( + string s = absl::StrCat( "ShapedBuffer(", platform_->Name(), ":", device_ordinal(), "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()), ", on-device shape=" + @@ -92,9 +91,9 @@ string ShapedBuffer::ToString() const { shape_str = 
ShapeUtil::HumanStringWithLayout(subshape); } const se::DeviceMemoryBase& memory = buffer(index); - Appendf(&s, " %s%p (%lld bytes) : %s\n", - string(index.size() * 2, ' ').c_str(), memory.opaque(), - memory.size(), shape_str.c_str()); + absl::StrAppendFormat(&s, " %s%p (%d bytes) : %s\n", + string(index.size() * 2, ' '), memory.opaque(), + memory.size(), shape_str); }); return s; } diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index 905a7e82e621f2bf4588b71be5dbab20f892cafe..e1d26da4a20c0105be304b1a34c81515fcdc6b7f 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -20,11 +20,11 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -84,6 +84,14 @@ class ShapedBuffer { *buffers_.mutable_element(index) = buffer; } + // Sets all buffers. + // + // Precondition: buffers.shape == on_device_shape_ + void set_buffers(ShapeTree buffers) { + CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_)); + buffers_ = std::move(buffers); + } + // Returns the underlying ShapeTree containing all the device addresses in the // ShapedBuffer. 
const ShapeTree& buffers() const { return buffers_; } diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc index 0fc243667911651c788e3c1e5f1d39d86170f1ad..d69e6362e91e4696dab3c46d99a981c67b593a1c 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -34,7 +35,7 @@ TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) { xla::StreamExecutorMemoryAllocator allocator(platform, executors); const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {}); const int kDeviceOrdinal = 0; - auto scoped_buffer = tensorflow::MakeUnique( + auto scoped_buffer = absl::make_unique( shape, shape, &allocator, kDeviceOrdinal); std::unique_ptr buffer = std::move(scoped_buffer); buffer = nullptr; diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc index 8cbaac7b3760717bcacb57adc8782a5755c0aa6d..dd53c7531bea4273b5f8dc1c993e7720eb1afeb2 100644 --- a/tensorflow/compiler/xla/service/source_map_util.cc +++ b/tensorflow/compiler/xla/service/source_map_util.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/source_map_util.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -26,11 +27,10 @@ Status InvalidParameterArgumentV(const OpMetadata& op_metadata, string message; tensorflow::strings::Appendv(&message, format, args); if (!op_metadata.source_file().empty()) { - tensorflow::strings::Appendf(&message, " (%s:%d)", - op_metadata.source_file().c_str(), - op_metadata.source_line()); + absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(), + op_metadata.source_line()); } - return InvalidArgument("%s", message.c_str()); + return InvalidArgument("%s", message); } } // namespace diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h index 18e2651abb1600a7b9ffb79de887b8795717e55e..c5a7e17cb44c2b3b5ef145da0d66b4b3160f9531 100644 --- a/tensorflow/compiler/xla/service/source_map_util.h +++ b/tensorflow/compiler/xla/service/source_map_util.h @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/macros.h" @@ -23,6 +24,19 @@ limitations under the License. namespace xla { namespace source_map_util { +// Creates an INVALID_ARGUMENT status with the given format string. +template +Status InvalidParameterArgument(const OpMetadata& op_metadata, + const absl::FormatSpec& format, + const Args&... 
args) { + string message = absl::StrFormat(format, args...); + if (!op_metadata.source_file().empty()) { + absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(), + op_metadata.source_line()); + } + return InvalidArgument("%s", message); +} + // Creates an INVALID_ARGUMENT status with the given format string. // // Also, attempts to extract the OpMetadata for parameter_number on executable @@ -30,17 +44,21 @@ namespace source_map_util { // // executable may be nullptr, but parameter_number should not be out of bounds // or a CHECK-failure may occur. +template Status InvalidParameterArgument(Executable* executable, int parameter_number, - const char* format, ...) - TF_PRINTF_ATTRIBUTE(3, 4); - -// As above, but takes the parameter metadata directly instead of extracting it -// from the executable. -Status InvalidParameterArgument(const OpMetadata& op_metadata, - const char* format, ...) - TF_PRINTF_ATTRIBUTE(2, 3); + const absl::FormatSpec& format, + const Args&... args) { + if (executable != nullptr && executable->has_module()) { + const HloModule& module = executable->module(); + const HloComputation& computation = *module.entry_computation(); + HloInstruction* param = computation.parameter_instruction(parameter_number); + const OpMetadata& metadata = param->metadata(); + return InvalidParameterArgument(metadata, format, args...); + } + return InvalidArgument(format, args...); +} } // namespace source_map_util } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc index c0582c6a2d3a05e2ed5aead5faac54e536d350cd..5d1cd1c4422a10e3b9e6ce6fac2c83594bb58b30 100644 --- a/tensorflow/compiler/xla/service/stream_pool.cc +++ b/tensorflow/compiler/xla/service/stream_pool.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/stream_pool.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -35,7 +35,7 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) { if (!stream) { // Create a new stream. - stream = MakeUnique(executor); + stream = absl::make_unique(executor); stream->Init(); VLOG(1) << stream->DebugStreamPointers() << " StreamPool created new stream"; diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 32d368a90429ec026120bdf033957617eeaba23e..b8d2d546e5d4dc67e3f314dfc6dcd4e8df5451c5 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -27,7 +29,7 @@ limitations under the License. 
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/notification.h" -using ::tensorflow::strings::StrCat; +using absl::StrCat; namespace xla { /* static */ tensorflow::mutex @@ -61,7 +63,7 @@ StatusOr> TransferManager::TransferLiteralFromDevice( if (!s.ok()) { return s; } - return MakeUnique(std::move(literal)); + return absl::make_unique(std::move(literal)); } Status TransferManager::TransferLiteralFromDevice( @@ -120,7 +122,7 @@ StatusOr> TransferManager::TransferArrayFromDevice( if (!s.ok()) { return s; } - return MakeUnique(std::move(literal)); + return absl::make_unique(std::move(literal)); } Status TransferManager::TransferArrayToDevice( @@ -147,7 +149,7 @@ Status TransferManager::TransferArrayToDeviceAsync( if (dest.size() < GetByteSizeRequirement(on_device_shape)) { return FailedPrecondition( "Allocation on device not large enough for array: " - "%lld < %lld", + "%d < %d", dest.size(), GetByteSizeRequirement(on_device_shape)); } ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, @@ -164,12 +166,12 @@ void TransferManager::TransferArrayFromDevice( auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), " has a differently shaped representation on-device: ", ShapeUtil::HumanString(HostShapeToDeviceShape(shape))); - return done(FailedPrecondition("%s", error.c_str())); + return done(FailedPrecondition("%s", error)); } if (source.size() < GetByteSizeRequirement(shape)) { return done( FailedPrecondition("Allocation on device not large enough for array: " - "%lld < %lld", + "%d < %d", source.size(), GetByteSizeRequirement(shape))); } ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, @@ -201,7 +203,7 @@ void TransferManager::TransferArrayFromDevice( return NotFound( "could not find registered transfer manager for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } if (it->second.manager == nullptr) { @@ -252,7 +254,7 @@ Status 
TransferManager::TransferBufferFromDevice( if (source.size() < size) { return FailedPrecondition( "Source allocation on device not large enough for data tranfer: " - "%lld < %lld", + "%d < %d", source.size(), size); } stream->ThenMemcpy(destination, source, size); @@ -265,7 +267,7 @@ Status TransferManager::TransferBufferToDevice( if (destination->size() < size) { return FailedPrecondition( "Destination allocation on device not large enough for data tranfer: " - "%lld < %lld", + "%d < %d", destination->size(), size); } stream->ThenMemcpy(destination, source, size); @@ -276,9 +278,8 @@ StatusOr TransferManager::AllocateScopedShapedBuffer( const Shape& on_host_shape, DeviceMemoryAllocator* allocator, int device_ordinal) { if (!LayoutUtil::HasLayout(on_host_shape)) { - return InvalidArgument( - "Shape must have a layout: %s", - ShapeUtil::HumanStringWithLayout(on_host_shape).c_str()); + return InvalidArgument("Shape must have a layout: %s", + ShapeUtil::HumanStringWithLayout(on_host_shape)); } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 475a2e5c141d66fa689fb402da1ee81fb4ab80f7..21725946b3629a4495d8ad6cc1529d712d22e0af 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -20,12 +20,12 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -130,7 +130,7 @@ class TransferManager { // Resets the devices associated with this transfer manager. virtual Status ResetDevices( - tensorflow::gtl::ArraySlice executor) = 0; + absl::Span executor) = 0; // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the @@ -152,6 +152,26 @@ class TransferManager { const Shape& on_host_shape, DeviceMemoryAllocator* allocator, int device_ordinal); + // The given ShapedBuffer holds a handle to allocated memory, but it is not + // in the general case legal to immediately copy or access that allocated + // memory because queued operations on the device may alias that memory. + // Memory ordering is enforced by the Stream's happens-before relationship + // which allows eager deallocation and reallocation of buffers host-side even + // if the device hasn't finished with them. + // + // In certain cases, it can be known that a ShapedBuffer does not have any + // conflicting accesses on the device and thus is eligible to be accessed at + // any time from the host. + // + // This function returns true if device_buffer can be accessed immediately + // without waiting for the Stream's previously enqueued items. This only + // returns true if all subbuffers in device_buffer can be accessed + // immediately. 
+ virtual bool CanShapedBufferBeAccessedNow( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const { + return false; + } + ///// // The TransferManager class also serves as a point to register objects for // the various platforms. @@ -191,8 +211,7 @@ class TransferManager { // to construct a tuple index table in the platform-specific tuple // representation. virtual Status WriteSingleTupleIndexTable( - se::Stream* stream, - tensorflow::gtl::ArraySlice elements, + se::Stream* stream, absl::Span elements, const Shape& shape, se::DeviceMemoryBase* region) = 0; private: diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 49e1f873192f800056a2272f7d4f698898b0f8a1..530f40e4b2f9c7c19fa29dad28a077b9d4d68a71 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -109,6 +109,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) { std::unique_ptr new_dot = HloInstruction::CreateDot( dot->shape(), new_lhs, new_rhs, new_dim_numbers); + new_dot->set_precision_config(dot->precision_config()); return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } @@ -178,6 +179,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); + new_conv->set_precision_config(convolution.precision_config()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h index 71e8446452f072c22bb730cbda65a1743a95cd4c..3e5aa2db60ee31d9fbccf8f7256b15c1b8465335 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.h +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -49,7 +49,7 @@ class TransposeFolding : public 
HloPassInterface { explicit TransposeFolding( TransposableGemmOperandsFn transposable_gemm_operands, TransposableConvOperandsFn transposable_conv_operands); - tensorflow::StringPiece name() const override { return "transpose-folding"; } + absl::string_view name() const override { return "transpose-folding"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 0447807a41b8b32ee297e1ca94393da8c687c5e6..6fed7c76d04ad5d8236fecd07aa27f1eda221ea7 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -19,6 +19,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -26,17 +30,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { string BufferAlias::ToString() const { - return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[", - tensorflow::str_util::Join(index_, ","), - "])"); + return absl::StrCat("BufferAlias(", instruction_->name(), "[", + absl::StrJoin(index_, ","), "])"); } std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) { @@ -360,7 +360,7 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { } Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) { - tensorflow::gtl::ArraySlice operands(tuple->operands()); + absl::Span operands(tuple->operands()); PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple); points_to_set.AddPointedToBuffer( logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}), @@ -441,7 +441,7 @@ PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet( PerInstruction* pi = PerInst(instruction); CHECK(pi->points_to_set == nullptr) << "instruction should not have been present in the map."; - auto set = MakeUnique(&instruction->shape()); + auto set = absl::make_unique(&instruction->shape()); pi->points_to_set = std::move(set); // Return *set using the iterator returned by emplace. 
return *pi->points_to_set; @@ -462,21 +462,20 @@ Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { return FailedPrecondition( "LogicalBuffer %s is ill-defined: instruction %s does not define a " "buffer at that index", - buffer.ToString().c_str(), buffer.instruction()->name().c_str()); + buffer.ToString(), buffer.instruction()->name()); } } if (buffer.id() < 0 || buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) { - return FailedPrecondition( - "LogicalBuffer %s is ill-defined: invalid id %lld", - buffer.ToString().c_str(), buffer.id()); + return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d", + buffer.ToString(), buffer.id()); } if (GetBuffer(buffer.id()).instruction() != buffer.instruction() || GetBuffer(buffer.id()).index() != buffer.index()) { return FailedPrecondition( "LogicalBuffer %s is ill-defined: buffer with same id differs: %s", - buffer.ToString().c_str(), GetBuffer(buffer.id()).ToString().c_str()); + buffer.ToString(), GetBuffer(buffer.id()).ToString()); } return Status::OK(); @@ -495,8 +494,7 @@ StatusOr TuplePointsToAnalysis::GetBufferDefinedAt( if (buffers.size() != 1 || buffers[0]->instruction() != instruction) { return FailedPrecondition( "instruction %s does not define buffer at index {%s}", - instruction->name().c_str(), - tensorflow::str_util::Join(index, ",").c_str()); + instruction->name(), absl::StrJoin(index, ",")); } return buffers[0]; } @@ -557,13 +555,12 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( } string TuplePointsToAnalysis::ToString() const { - string output = tensorflow::strings::Printf( - "TuplePointsToSet for module %s:\n", module_->name().c_str()); + string output = + absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name()); for (const auto* computation : module_->MakeNonfusionComputations()) { const char* entry = computation == module_->entry_computation() ? 
"entry " : ""; - tensorflow::strings::StrAppend(&output, entry, "computation ", - computation->name(), ":\n"); + absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n"); for (const HloInstruction* instruction : computation->MakeInstructionPostOrder()) { InstructionToString(instruction, &output); @@ -575,12 +572,11 @@ string TuplePointsToAnalysis::ToString() const { } } - tensorflow::strings::StrAppend(&output, "LogicalBuffers:\n"); + absl::StrAppend(&output, "LogicalBuffers:\n"); for (const auto& b : logical_buffer_analysis_->logical_buffers()) { - tensorflow::strings::StrAppend(&output, " buffer ", b->ToString(), ":\n"); + absl::StrAppend(&output, " buffer ", b->ToString(), ":\n"); for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) { - tensorflow::strings::StrAppend(&output, " alias ", alias.ToString(), - "\n"); + absl::StrAppend(&output, " alias ", alias.ToString(), "\n"); } } return output; @@ -589,20 +585,18 @@ string TuplePointsToAnalysis::ToString() const { void TuplePointsToAnalysis::InstructionToString( const HloInstruction* instruction, string* output) const { const string prefix = instruction->IsFused() ? 
" " : ""; - tensorflow::strings::StrAppend(output, prefix, " instruction ", - instruction->ToShortString(), ":\n"); + absl::StrAppend(output, prefix, " instruction ", + instruction->ToShortString(), ":\n"); const PointsToSet& points_to_set = GetPointsToSet(instruction); points_to_set.ForEachElement([&prefix, &output]( const ShapeIndex& index, const PointsToSet::BufferList& points_to) { - tensorflow::strings::StrAppend( - output, prefix, " {", tensorflow::str_util::Join(index, ","), "}: ", - tensorflow::str_util::Join( - points_to, ", ", - [](string* out, const LogicalBuffer* source) { - out->append(source->ToString()); - }), - "\n"); + absl::StrAppend(output, prefix, " {", absl::StrJoin(index, ","), "}: ", + absl::StrJoin(points_to, ", ", + [](string* out, const LogicalBuffer* source) { + out->append(source->ToString()); + }), + "\n"); }); } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 686bb053288fbd6a46ca50a2c65c739354fd2678..a9e8a51e0923362162c6b8a2e97fc334e56d4329 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -33,7 +35,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/compactptrset.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -109,7 +110,7 @@ class PointsToSet { // Add a tuple source instruction for the given index. void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple); - using BufferList = tensorflow::gtl::InlinedVector; + using BufferList = absl::InlinedVector; // Return the list of logical buffers for the subshape at index. const BufferList& element(const ShapeIndex& index) const { @@ -203,7 +204,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { // logical buffer The buffer alias set is the inverse of the points-to set. // That is, LogicalBuffer B is in the points-to set of instruction I at index // N iff instruction I, index N is a BufferAlias of B. - using BufferAliasVector = tensorflow::gtl::InlinedVector; + using BufferAliasVector = absl::InlinedVector; const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const; // Returns the number of logical buffers in the module @@ -226,8 +227,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { // instructions produce a single buffer (the top-level buffer), some produce // no buffers (eg bitcast), and some produce more than one buffer (eg, // tuple-shaped parameters). 
- using BufferDefinitionVector = - tensorflow::gtl::InlinedVector; + using BufferDefinitionVector = absl::InlinedVector; const BufferDefinitionVector& GetBuffersDefinedByInstruction( const HloInstruction* instruction) const; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 10d382e8abc92145c1804cbf18bbed714fa34571..a32d1f9026e8beae77b5b40241995707ff62231e 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -72,9 +72,8 @@ class TuplePointsToAnalysisTest : public HloTestBase { // Checks that the given points-to set contains exactly (unordered) the given // LogicalBuffers. - void ExpectHasBuffers( - const PointsToSet::BufferList& points_to_set, - tensorflow::gtl::ArraySlice buffers) { + void ExpectHasBuffers(const PointsToSet::BufferList& points_to_set, + absl::Span buffers) { std::vector vec(buffers.begin(), buffers.end()); EXPECT_THAT(points_to_set, UnorderedElementsAreArray(vec)); } @@ -83,7 +82,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // top-level buffers of the given instructions. void ExpectHasTopLevelBuffers( const PointsToSet::BufferList& points_to_set, - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { PointsToSet::BufferList buffers; for (auto instruction : instructions) { buffers.push_back(GetBuffer(instruction, /*index=*/{})); @@ -94,7 +93,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // Overload which takes a set instead of a vector. 
void ExpectHasTopLevelBuffers( const PointsToSet::BufferSet& points_to_set, - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { ExpectHasTopLevelBuffers( PointsToSet::BufferList(points_to_set.begin(), points_to_set.end()), instructions); @@ -104,8 +103,7 @@ class TuplePointsToAnalysisTest : public HloTestBase { // aliases which are exactly (unordered) the given instruction/index pairs. void ExpectHasBufferAliases( const HloInstruction* instruction, const ShapeIndex& index, - tensorflow::gtl::ArraySlice> - expected) { + absl::Span> expected) { const LogicalBuffer* buffer = points_to_analysis_->GetBufferDefinedAt(instruction, index) .ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index 750950188312c5077d487f2feef0606f07839432..8c91d6e69de637d58fa2ffc1a32ea65f09d3b6d8 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -30,7 +30,7 @@ class TupleSimplifier : public HloPassInterface { TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} explicit TupleSimplifier(bool exclude_entry_computation); ~TupleSimplifier() override {} - tensorflow::StringPiece name() const override { return "tuple-simplifier"; } + absl::string_view name() const override { return "tuple-simplifier"; } // Run tuple simplification on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc index 4a530bb0b20582b303f4af969514748b46fd5064..cfb0c787d09557fd1aec3517eb9698cfec323369 100644 --- a/tensorflow/compiler/xla/service/tuple_util.cc +++ b/tensorflow/compiler/xla/service/tuple_util.cc @@ -14,8 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/tuple_util.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -40,7 +40,7 @@ namespace xla { /*static*/ HloInstruction* TupleUtil::AppendSuffix( HloInstruction* input_tuple, - tensorflow::gtl::ArraySlice trailing_values) { + absl::Span trailing_values) { CHECK(ShapeUtil::IsTuple(input_tuple->shape())); HloComputation* computation = input_tuple->parent(); diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h index e5ff9aaa8357fe8e4777d6dee37bbec72e144c06..bc5aac09f270c01515b1f3a704af6949f24cb218 100644 --- a/tensorflow/compiler/xla/service/tuple_util.h +++ b/tensorflow/compiler/xla/service/tuple_util.h @@ -38,7 +38,7 @@ class TupleUtil { // `input_tuple`. static HloInstruction* AppendSuffix( HloInstruction* input_tuple, - tensorflow::gtl::ArraySlice trailing_values); + absl::Span trailing_values); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index af2cb6dc2a3f4a004351acc62796e0daf46719c2..c3c2603c7eb58d3e57346d2ea1e0058f8e5d7fe8 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -18,8 +18,8 @@ limitations under the License. namespace xla { -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; +using absl::nullopt; +using absl::optional; // Finds and returns the non-constant operand in instr. 
// @@ -211,8 +211,7 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, VLOG(2) << "Couldn't evaluate while cond: " << result.status(); return nullopt; } - if (result.ValueOrDie()->data() == - tensorflow::gtl::ArraySlice{false}) { + if (result.ValueOrDie()->data() == absl::Span{false}) { VLOG(2) << "Loop has static trip count of " << trip_count; return trip_count; } diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h index bf59813e8c405a8709446bf8457729348ceae4ec..bf497f4892b95c927379411468a66d8961465413 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.h +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -25,8 +25,8 @@ namespace xla { // nullopt otherwise. max_value_returned limits the number of steps that are // evaluated while trying to brute force a loop trip count, trip counts larger // than max_value_returned result in nullopt. -tensorflow::gtl::optional ComputeWhileLoopTripCount( - HloInstruction *while_op, int64 max_value_returned = 128); +absl::optional ComputeWhileLoopTripCount(HloInstruction *while_op, + int64 max_value_returned = 128); } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 62af45128ad2fb7bf886bef78ec3ab42529a181e..aab11806621746141f4302f39a780fcdbab99fc1 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -32,7 +33,7 @@ static Status ReplaceUsesWhileKeepingLoopInvariance( std::vector users; users.reserve(old_instr->user_count()); - c_copy(old_instr->users(), std::back_inserter(users)); + absl::c_copy(old_instr->users(), std::back_inserter(users)); for (auto* user : users) { for (int64 i = 0, e = user->operand_count(); i < e; i++) { @@ -108,10 +109,10 @@ StatusOr WhileLoopConstantSinking::Run(HloModule* module) { // // This will let us sink the constant into the outer while first and then // into the inner while in a single run of this pass. - c_copy_if(comp->instructions(), std::back_inserter(while_instrs), - [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kWhile; - }); + absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); } for (HloInstruction* while_instr : while_instrs) { diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h index 21fb8568a84985692026e145c363500a154a1599..2dba7d7f7574742a301e3503e353bbe57d72a203 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -54,7 +54,7 @@ class WhileLoopConstantSinking : public HloPassInterface { public: ~WhileLoopConstantSinking() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "while-loop-invariant-code-motion"; } diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc 
b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 266039d2ff8ef4befba0d1023ac1914737207d4f..0e7667de832c54f647d071e3c9563091d0f994aa 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -206,7 +206,8 @@ body { p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 - outfeed = token[] outfeed(p_body.0) + token = token[] after-all() + outfeed = token[] outfeed(p_body.0, token) ROOT root = (f32[2],f32[2],f32[2]) tuple(p_body.0, p_body.1, p_body.1) } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 09ddcffb22c2184262adf87d570870ec000c0e6f..e8fe33e62659ae0fffff1ad46e8ba77f715b76b2 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -14,18 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { +using absl::InlinedVector; using tensorflow::gtl::FlatMap; using tensorflow::gtl::FlatSet; -using tensorflow::gtl::InlinedVector; // Copies `to_hoist` to the computation containing `while_instr`, hoisting its // operands as needed. 
All of its transitive operands are expected to be either @@ -65,8 +66,8 @@ static void CreateLoopInvariantCopy( }; InlinedVector new_operands; - c_transform(old_instruction->operands(), std::back_inserter(new_operands), - get_new_operand); + absl::c_transform(old_instruction->operands(), + std::back_inserter(new_operands), get_new_operand); HloInstruction* new_instruction = parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands( @@ -109,6 +110,7 @@ bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( case HloOpcode::kBitcast: case HloOpcode::kBroadcast: + case HloOpcode::kIota: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: @@ -197,7 +199,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( op->opcode() == HloOpcode::kConstant; }; - if (!c_all_of(instruction->operands(), is_invariant)) { + if (!absl::c_all_of(instruction->operands(), is_invariant)) { continue; } @@ -257,10 +259,10 @@ StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { bool changed = false; std::vector while_instrs; for (auto* comp : module->computations()) { - c_copy_if(comp->instructions(), std::back_inserter(while_instrs), - [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kWhile; - }); + absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); } for (HloInstruction* while_instr : while_instrs) { diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 8e6cc8787576e4f041229da5cf8dd2b09194eb2a..2cdf20ce80362c0aeb9d8324573e7e9826cc018c 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -38,7 +38,7 @@ class WhileLoopInvariantCodeMotion : public HloPassInterface { : 
hoist_constants_(hoist_constants) {} ~WhileLoopInvariantCodeMotion() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "while-loop-invariant-code-motion"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index dd8697e680c56165f87c365a721eda2de1ebc085..6a7bfe3f129d97866ccc54897d584fab0f7c683e 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -14,17 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/optional.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; +using absl::optional; // Determines whether the given instruction is a send/recv node, or has a // subcomputation which contains a send/recv node. 
@@ -237,12 +236,11 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { << "Instruction " << user->ToString(print_no_metadata) << " should be unused (except by root of while body), but has " "users: {" - << tensorflow::str_util::Join( - user->users(), ", ", - [&](string* out, const HloInstruction* instr) { - tensorflow::strings::StrAppend( - out, instr->ToString(print_no_metadata)); - }) + << absl::StrJoin(user->users(), ", ", + [&](string* out, const HloInstruction* instr) { + absl::StrAppend( + out, instr->ToString(print_no_metadata)); + }) << "}"; replacements.emplace(user, nullptr); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 3d3e1d60f294c3a2574513c1c2f071805a341ad1..78024f14dc89ff40a11bbc3602072fda1fe6f312 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -33,9 +33,7 @@ namespace xla { class WhileLoopSimplifier : public HloPassInterface { public: ~WhileLoopSimplifier() override {} - tensorflow::StringPiece name() const override { - return "simplify-while-loops"; - } + absl::string_view name() const override { return "simplify-while-loops"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 2e1571943e537f772ee7dcd95c80ba540445b76e..1c892ba179ec67ccc9dbfe93d925551d6977ba15 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -15,11 +15,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -64,10 +65,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { } )"; - string hlo_string = tensorflow::str_util::StringReplace( - hlo_string_template, "{{LOOP_BOUND}}", - tensorflow::strings::StrCat(42 + num_iters), - /*replace_all=*/true); + string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); ParseAndVerifyModule(hlo_string); } @@ -103,10 +102,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( } )"; - string hlo_string = tensorflow::str_util::StringReplace( - hlo_string_template, "{{LOOP_BOUND}}", - tensorflow::strings::StrCat(42 + num_iters), - /*replace_all=*/true); + string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); ParseAndVerifyModule(hlo_string); } diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 1ef17b9d7d2e769aadf39f8a70f78200b88e9d2c..f90ac91f9d07aded8cafccf82dae894c9a149bd1 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -14,15 +14,16 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/while_util.h" +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/tuple_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::strings::StrCat; +using absl::StrCat; static StatusOr WidenWhileCondition( HloComputation* narrow_condition, const Shape& wide_shape) { @@ -93,7 +94,7 @@ WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { /*static*/ StatusOr WhileUtil::MakeInstructionsLiveIn( HloInstruction* while_instr, - tensorflow::gtl::ArraySlice instructions) { + absl::Span instructions) { CHECK(ShapeUtil::IsTuple(while_instr->shape())); int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size(); @@ -206,7 +207,7 @@ static StatusOr MakeInitTupleFromInitValues( HloInstruction* zero = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); init_values_with_indvar.push_back(zero); - c_copy(init_values, std::back_inserter(init_values_with_indvar)); + absl::c_copy(init_values, std::back_inserter(init_values_with_indvar)); return computation->AddInstruction( HloInstruction::CreateTuple(init_values_with_indvar)); } @@ -215,8 +216,9 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { std::vector loop_state_shape_components; loop_state_shape_components.reserve(init_values.size() + 1); loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {})); - c_transform(init_values, std::back_inserter(loop_state_shape_components), - [](HloInstruction* instr) { return instr->shape(); }); + absl::c_transform(init_values, + std::back_inserter(loop_state_shape_components), + [](HloInstruction* 
instr) { return instr->shape(); }); return ShapeUtil::MakeTupleShape(loop_state_shape_components); } diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index e67636d80f4b682fe1335eae535fb86105ac082b..b1c4486887ae0ddbe2ba4e79f45a265689111017 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -55,7 +55,7 @@ class WhileUtil { // that contains `while_instr`. static StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, - tensorflow::gtl::ArraySlice instructions); + absl::Span instructions); using LoopStateTy = std::vector; using LoopBodyGeneratorTy = std::function( diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 2ccb919acf9c4e7c59a1ebaf36f42a6781068b5e..5e6941933330fde29bc9c779aae4bb3c36914660 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/while_util.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" @@ -206,7 +207,7 @@ ENTRY main { auto is_while = [](const HloInstruction* instr) { return instr->opcode() == HloOpcode::kWhile; }; - EXPECT_EQ(c_count_if(main->instructions(), is_while), 1); + EXPECT_EQ(absl::c_count_if(main->instructions(), is_while), 1); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index 8763e588c484011ba2ccbc7cad8f29817347a605..a7f0e207eb5a81b04bb28977d6f5e38864ad2d6a 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -24,7 +24,7 @@ namespace xla { class ZeroSizedHloElimination : public HloPassInterface { public: StatusOr Run(HloModule* module) override; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "zero_sized_hlo_elimination"; } }; diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index caad31d6ce7ce35fa362ec364b0d7f1d95973715..d44db89d571891ecef554cd45c050017833982bb 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -25,8 +25,8 @@ namespace xla { Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { if (!ShapeUtil::Compatible(other_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(other_shape).c_str(), - ShapeUtil::HumanString(shape()).c_str()); + ShapeUtil::HumanString(other_shape), + ShapeUtil::HumanString(shape())); } shape_ = other_shape; return Status::OK(); @@ -35,8 +35,8 @@ Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { Status 
ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(*to_shape).c_str(), - ShapeUtil::HumanString(shape()).c_str()); + ShapeUtil::HumanString(*to_shape), + ShapeUtil::HumanString(shape())); } *to_shape = shape_; return Status::OK(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index c74dd648addd70633edc2ec10a60879a00942716..52c895e8d4b2aa55b55df41b7139b00c576d6e99 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -21,16 +21,16 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -262,6 +262,25 @@ class ShapeTree { template Status ForEachMutableElementWithStatus(const Fn& func); + // Maps each element to generate a new tree with the same shape. 
+ template + ShapeTree Map(const std::function& func) { + ShapeTree result(shape_storage_); + ForEachElement([&](const ShapeIndex& index, const T& t) { + *result.mutable_element(index) = func(t); + }); + return result; + } + + template + ShapeTree Map(const std::function& func) { + ShapeTree result(shape_storage_); + ForEachMutableElement([&](const ShapeIndex& index, T* t) { + *result.mutable_element(index) = func(t); + }); + return result; + } + // Copy the subtree of values from 'other' rooted at ShapeIndex // 'source_base_index' into the subtree of value in this ShapeTree rooted at // 'target_base_index'. @@ -463,9 +482,6 @@ template ShapeTree::ShapeTree(Shape shape) : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { - // The shape_ field is just used to hold the structure of the shape. - // It should not be relied upon to store layout information. - LayoutUtil::ClearLayout(shape_storage_.get()); const int64 count = CountSubshapes(*shape_); nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}); @@ -502,9 +518,6 @@ template ShapeTree::ShapeTree(Shape shape, const T& init_value) : shape_storage_(std::make_shared(std::move(shape))), shape_(shape_storage_.get()) { - // The shape_ field is just used to hold the structure of the shape. - // It should not be relied upon to store layout information. - LayoutUtil::ClearLayout(shape_storage_.get()); const int64 count = CountSubshapes(*shape_); nodes_.reserve(count); nodes_.emplace_back(ShapeIndex{}, init_value); diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index c4c958be4a18f23b8e34f9e619e447c6bf4334b5..c8ff55e7845785d9292516b823fb591cc28cbfad 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_tree.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -242,7 +243,7 @@ TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) { TEST_F(ShapeTreeTest, ShapeTreeOfNonCopyableType) { ShapeTree> shape_tree{tuple_shape_}; EXPECT_EQ(shape_tree.element({2}).get(), nullptr); - *shape_tree.mutable_element({2}) = MakeUnique(42); + *shape_tree.mutable_element({2}) = absl::make_unique(42); EXPECT_EQ(*shape_tree.element({2}), 42); } diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 34869cc5078699603c006387161fddd4fee4a9f8..9772c06bce32cef0d79a036b525c3606ea60e31b 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,6 +22,14 @@ limitations under the License. #include #include +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/overflow_util.h" @@ -30,26 +38,22 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" namespace xla { -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); } string ShapeIndexView::ToString() const { - return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); + return StrCat("{", absl::StrJoin(indices_, ","), "}"); } bool ShapeIndexView::operator==(const ShapeIndexView& other) const { @@ -91,11 +95,11 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, } if (ShapeUtil::IsTuple(lhs)) { - return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), - [=](const Shape& l, const Shape& r) { - return CompareShapes(l, r, compare_layouts, - ignore_fp_precision); - }); + return absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), + [=](const Shape& l, const Shape& r) { + return CompareShapes(l, r, compare_layouts, + ignore_fp_precision); + }); } else if (!ShapeUtil::IsArray(lhs)) { // Non-tuple, non-array tupes such as opaque and token types are trivially // the same. 
@@ -107,13 +111,13 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, return false; } if (LayoutUtil::IsDenseArray(lhs)) { - if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs), - LayoutUtil::MinorToMajor(rhs))) { + if (!absl::c_equal(LayoutUtil::MinorToMajor(lhs), + LayoutUtil::MinorToMajor(rhs))) { VLOG(3) << "CompareShapes: lhs layout != rhs layout"; return false; } - if (!ContainersEqual(lhs.layout().padded_dimensions(), - rhs.layout().padded_dimensions())) { + if (!absl::c_equal(lhs.layout().padded_dimensions(), + rhs.layout().padded_dimensions())) { VLOG(3) << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions"; return false; @@ -135,15 +139,15 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts, // Constructs and returns the new shape with the given minor_to_major order in // its Layout. StatusOr MakeShapeWithLayoutInternal( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { + PrimitiveType element_type, absl::Span dimensions, + absl::Span minor_to_major) { if (dimensions.size() != minor_to_major.size()) { return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", dimensions.size(), minor_to_major.size()); } if (element_type == OPAQUE || element_type == TUPLE) { return InvalidArgument("Unsupported element type: %s", - PrimitiveType_Name(element_type).c_str()); + PrimitiveType_Name(element_type)); } Shape shape = ShapeUtil::MakeShape(element_type, dimensions); auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); @@ -210,8 +214,8 @@ StatusOr MakeShapeWithLayoutInternal( return program_shape; } -/* static */ Shape ShapeUtil::MakeShape( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { +/* static */ Shape ShapeUtil::MakeShape(PrimitiveType element_type, + absl::Span dimensions) { CHECK(IsArrayPrimitiveType(element_type)); Shape result; PopulateShape(element_type, dimensions, 
&result); @@ -219,21 +223,21 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ Shape ShapeUtil::MakeShapeWithLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major) { + PrimitiveType element_type, absl::Span dimensions, + absl::Span minor_to_major) { return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) .ValueOrDie(); } /* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { + PrimitiveType element_type, absl::Span dimensions) { std::vector layout(dimensions.size()); std::iota(layout.rbegin(), layout.rend(), static_cast(0)); return MakeShapeWithLayout(element_type, dimensions, layout); } /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + PrimitiveType element_type, absl::Span dimensions, int64 max_sparse_elements) { CHECK(IsArrayPrimitiveType(element_type)); Shape shape = ShapeUtil::MakeShape(element_type, dimensions); @@ -252,9 +256,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return MakeShapeWithDescendingLayout(shape.element_type(), dims); } -/* static */ void ShapeUtil::PopulateShape( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - Shape* shape) { +/* static */ void ShapeUtil::PopulateShape(PrimitiveType element_type, + absl::Span dimensions, + Shape* shape) { shape->Clear(); shape->set_element_type(element_type); for (int64 dimension : dimensions) { @@ -264,8 +268,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( TF_DCHECK_OK(ValidateShape(*shape)); } -/* static */ Shape ShapeUtil::MakeTupleShape( - tensorflow::gtl::ArraySlice shapes) { +/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span shapes) { Shape result; result.set_element_type(TUPLE); result.mutable_tuple_shapes()->Reserve(shapes.size()); @@ -449,14 +452,14 @@ 
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( namespace { // Class to memoize the computation of -// tensorflow::str_util::Lowercase(PrimitiveType_Name(p)) +// absl::AsciiStrToLower(PrimitiveType_Name(p)) // for all PrimitiveType values "p" class PrimitiveTypeNameGenerator { public: PrimitiveTypeNameGenerator() { for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { if (PrimitiveType_IsValid(i)) { - lowercase_name_[i] = tensorflow::str_util::Lowercase( + lowercase_name_[i] = absl::AsciiStrToLower( PrimitiveType_Name(static_cast(i))); } } @@ -487,8 +490,7 @@ StatusOr StringToPrimitiveType(const string& name) { }(); auto found = name_to_type->find(name); if (found == name_to_type->end()) { - return InvalidArgument("Invalid element type string: \"%s\".", - name.c_str()); + return InvalidArgument("Invalid element type string: \"%s\".", name); } return found->second; } @@ -507,7 +509,7 @@ StatusOr StringToPrimitiveType(const string& name) { return text; } return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", - tensorflow::str_util::Join(shape.dimensions(), ","), "]"); + absl::StrJoin(shape.dimensions(), ","), "]"); } /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { @@ -543,30 +545,29 @@ StatusOr StringToPrimitiveType(const string& name) { : "(unknown)", ": ", HumanString(shape))); } - return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ", + return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ", HumanString(program_shape.result())); } namespace { // Parses shapes with simple recursive descent structure -- consumes from the // front of s and passes that view recursively as required. -StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { - tensorflow::str_util::RemoveLeadingWhitespace(s); +StatusOr ParseShapeStringInternal(absl::string_view* s) { + *s = StripLeadingAsciiWhitespace(*s); - if (tensorflow::str_util::ConsumePrefix(s, "(")) { // Tuple. 
+ if (absl::ConsumePrefix(s, "(")) { // Tuple. std::vector shapes; bool must_end = false; while (true) { - if (tensorflow::str_util::ConsumePrefix(s, ")")) { + if (absl::ConsumePrefix(s, ")")) { break; } else if (must_end) { - return InvalidArgument("Expected end of tuple; got: \"%s\"", - std::string(*s).c_str()); + return InvalidArgument("Expected end of tuple; got: \"%s\"", *s); } shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); - tensorflow::str_util::RemoveLeadingWhitespace(s); - must_end = !tensorflow::str_util::ConsumePrefix(s, ","); + *s = StripLeadingAsciiWhitespace(*s); + must_end = !absl::ConsumePrefix(s, ","); } return ShapeUtil::MakeTupleShape(shapes); } @@ -575,9 +576,9 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { string dimensions_string; string format_string; string layout_string; - // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so + // absl::string_view is not compatible with internal RE2 StringPiece, so // we convert in to the RE2-consumable type and then consume the corresponding - // amount from our StringPiece type. + // amount from our string_view type. 
static LazyRE2 shape_pattern = { "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"}; tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); @@ -585,12 +586,12 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { &dimensions_string, &format_string, &layout_string)) { size_t consumed = s->size() - s_consumable.size(); s->remove_prefix(consumed); - auto string_to_int64 = [&s](const string& input) -> StatusOr { + auto string_to_int64 = [&s](absl::string_view input) -> StatusOr { int64 element; - if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) { + if (!absl::SimpleAtoi(input, &element)) { return InvalidArgument( - "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", - input.c_str(), std::string(*s).c_str()); + "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", input, + *s); } return element; }; @@ -598,7 +599,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { auto comma_list_to_int64s = [string_to_int64](const string& input) -> StatusOr> { std::vector results; - for (const string& piece : tensorflow::str_util::Split(input, ',')) { + for (const auto& piece : absl::StrSplit(input, ',', absl::SkipEmpty())) { TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); results.push_back(element); } @@ -614,7 +615,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { StringToPrimitiveType(element_type_string)); if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { return InvalidArgument("Invalid element type string: \"%s\".", - element_type_string.c_str()); + element_type_string); } Shape result; @@ -644,17 +645,14 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return std::move(result); } - return InvalidArgument("Invalid shape string to parse: \"%s\"", - std::string(*s).c_str()); + return InvalidArgument("Invalid shape string to parse: \"%s\"", *s); } } // namespace -/* static */ StatusOr ShapeUtil::ParseShapeString( - 
tensorflow::StringPiece s) { +/* static */ StatusOr ShapeUtil::ParseShapeString(absl::string_view s) { TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); if (!s.empty()) { - return InvalidArgument("Invalid shape string to parse: \"%s\"", - std::string(s).c_str()); + return InvalidArgument("Invalid shape string to parse: \"%s\"", s); } return shape; } @@ -663,7 +661,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { const Shape& rhs) { CHECK(ShapeUtil::IsArray(lhs)); CHECK(ShapeUtil::IsArray(rhs)); - return ContainersEqual(lhs.dimensions(), rhs.dimensions()); + return absl::c_equal(lhs.dimensions(), rhs.dimensions()); } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { @@ -677,8 +675,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return IsArray(rhs) && SameDimensions(lhs, rhs); } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), - CompatibleIgnoringElementType); + absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringElementType); } else { // Opaque, token, etc types are vacuously compatible. return lhs.element_type() == rhs.element_type(); @@ -692,8 +690,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { CompatibleIgnoringElementType(lhs, rhs); } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), - CompatibleIgnoringFpPrecision); + absl::c_equal(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringFpPrecision); } else { // Opaque, token, etc types are vacuously compatible. 
return lhs.element_type() == rhs.element_type(); @@ -792,7 +790,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ShortDebugString(); - tensorflow::gtl::ArraySlice padded_dimensions = + absl::Span padded_dimensions = LayoutUtil::PaddedDimensions(shape); if (!padded_dimensions.empty()) { CHECK_EQ(Rank(shape), padded_dimensions.size()); @@ -819,7 +817,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { const Shape& shape) { if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { return InvalidArgument("shape has invalid element type: %s", - shape.ShortDebugString().c_str()); + shape.ShortDebugString()); } if (shape.element_type() == TUPLE) { if (shape.dimensions_size() != 0) { @@ -842,21 +840,21 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (shape.dimensions_size() != 0) { return InvalidArgument( "shape has %s element type, but has dimensions field: %s", - LowercasePrimitiveTypeName(shape.element_type()).c_str(), - shape.ShortDebugString().c_str()); + LowercasePrimitiveTypeName(shape.element_type()), + shape.ShortDebugString()); } if (shape.has_layout()) { return InvalidArgument( "shape has %s element type, but has layout field: %s", - LowercasePrimitiveTypeName(shape.element_type()).c_str(), - shape.ShortDebugString().c_str()); + LowercasePrimitiveTypeName(shape.element_type()), + shape.ShortDebugString()); } return Status::OK(); } if (Rank(shape) != shape.dimensions_size()) { return InvalidArgument( - "shape's rank is mismatched with dimension count; rank=%lld " + "shape's rank is mismatched with dimension count; rank=%d " "dimensions_size=%d", Rank(shape), shape.dimensions_size()); } @@ -864,9 +862,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { int64 dimension = shape.dimensions(i); if (dimension < 0) { return InvalidArgument( - "shape's dimensions must not 
be < 0; dimension at index %lld was " - "%lld", - i, dimension); + "shape's dimensions must not be < 0; dimension at index %d was %d", i, + dimension); } } @@ -931,7 +928,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (shape_size < 0) { return InvalidArgument("Shape %s size may overflow int64.", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } VLOG(3) << "Shape size is valid: " << shape_size; @@ -991,7 +988,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { i >= return_shape->tuple_shapes_size()) { return InvalidArgument( "Shape index %s not a valid subshape index for tuple with shape %s", - index.ToString().c_str(), shape.DebugString().c_str()); + index.ToString(), shape.DebugString()); } return_shape = &return_shape->tuple_shapes(i); } @@ -1014,12 +1011,13 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { } /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { + if (!IsTuple(shape)) { + return 1; + } int64 count = 0; - ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) { - if (IsLeafIndex(shape, index)) { - ++count; - } - }); + for (const Shape& subshape : shape.tuple_shapes()) { + count += GetLeafCount(subshape); + } return count; } @@ -1036,7 +1034,7 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { /* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) { CHECK(ShapeUtil::IsArray(shape)); - return ArrayContains(AsInt64Slice(shape.dimensions()), 1); + return absl::c_linear_search(shape.dimensions(), 1); } namespace { @@ -1116,7 +1114,7 @@ Status ForEachMutableSubshapeHelper( } /* static */ Shape ShapeUtil::PermuteDimensions( - tensorflow::gtl::ArraySlice permutation, const Shape& shape) { + absl::Span permutation, const Shape& shape) { Shape new_shape = shape; new_shape.clear_dimensions(); for (auto dim : Permute(permutation, shape.dimensions())) { @@ -1171,8 +1169,7 @@ Status 
ForEachMutableSubshapeHelper( CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation))) << "shape=" << HumanStringWithLayout(shape) << ", new_shape=" << HumanStringWithLayout(new_shape) - << ", permutation={" << tensorflow::str_util::Join(permutation, ",") - << "}"; + << ", permutation={" << absl::StrJoin(permutation, ",") << "}"; } return new_shape; } @@ -1261,7 +1258,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, - tensorflow::gtl::ArraySlice dimension_mapping) { + absl::Span dimension_mapping) { CHECK(LayoutUtil::HasLayout(input_shape) && LayoutUtil::HasLayout(output_shape)); @@ -1288,7 +1285,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, // apply(input_dimensions, I) = // apply((dimension_mapping * output_dimensions), I) // input_dimensions = dimension_mapping * output_dimensions - return ContainersEqual( + return absl::c_equal( ComposePermutations(dimension_mapping, AsInt64Slice(output_shape.layout().minor_to_major())), input_shape.layout().minor_to_major()); @@ -1459,7 +1456,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, check_input_unit_indices(output_shape, input_shape); } -/* static */ tensorflow::gtl::optional ShapeUtil::AlignLayouts( +/* static */ absl::optional ShapeUtil::AlignLayouts( const Shape& input_shape, const Shape& output_shape) { CHECK(IsArray(input_shape)); CHECK(IsArray(output_shape)); @@ -1498,7 +1495,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, if (input_dimension_product < output_dimension_product || j == output_rank) { if (i == input_rank) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } dimension_to_alignment_index[i] = alignment.size() - 1; input_dimension_product *= input_shape.dimensions(i); @@ -1509,7 +1506,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, } } if 
(input_dimension_product != output_dimension_product) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } // We also need to store an end element so that we know where the last // alignment part ends. @@ -1553,7 +1550,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, for (int64 j = 0; j < num_non_trivial_dimensions_in_alignment_part; ++i, ++j) { if (i == input_rank) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } // Skip trivial dimensions with a bound of 1. if (input_shape.dimensions(input_dimension_numbers[i]) == 1) { @@ -1566,7 +1563,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, if (dimension_to_alignment_index[input_dimension_numbers[i]] != current_alignment_index || input_dimension_numbers[i] > current_dimension_number) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } current_dimension_number = input_dimension_numbers[i]; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index d6f17fc965d24bbbbd083b8dd0ec11a59e49ed4e..8234fcdd3f57978b94630d4e2880826dd678389f 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -22,6 +22,9 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -30,9 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" @@ -74,7 +74,7 @@ class ShapeIndex { // push_front is O(n^2), but shapes don't usually have a ton of dimensions. void push_front(int64 value) { indices_.insert(indices_.begin(), value); } - using container_type = tensorflow::gtl::InlinedVector; + using container_type = absl::InlinedVector; container_type::const_iterator begin() const { return indices_.begin(); } container_type::const_iterator end() const { return indices_.end(); } @@ -131,12 +131,12 @@ class ShapeIndexView { } ShapeIndexView ConsumeFront() const { ShapeIndexView result = *this; - result.indices_.pop_front(); + result.indices_.remove_prefix(1); return result; } ShapeIndexView ConsumeBack() const { ShapeIndexView result = *this; - result.indices_.pop_back(); + result.indices_.remove_suffix(1); return result; } ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); } @@ -147,7 +147,7 @@ class ShapeIndexView { string ToString() const; private: - tensorflow::gtl::ArraySlice indices_; + absl::Span indices_; }; std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); @@ -228,7 +228,7 @@ class ShapeUtil { // Parses a ShapeUtil::HumanString-format shape string back into a shape // object. - static StatusOr ParseShapeString(tensorflow::StringPiece s); + static StatusOr ParseShapeString(absl::string_view s); // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. 
@@ -328,7 +328,7 @@ class ShapeUtil { static Shape ChangeElementType(const Shape& original, PrimitiveType type); // Creates a tuple shape from a slice of element shapes within the tuple. - static Shape MakeTupleShape(tensorflow::gtl::ArraySlice shapes); + static Shape MakeTupleShape(absl::Span shapes); // Creates an opaque shape. These are generally used for threading a context // into a custom operation. @@ -355,31 +355,29 @@ class ShapeUtil { // Constructs a new shape with the given element type and sequence of // dimensions. static Shape MakeShape(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions); + absl::Span dimensions); // Creates a Shape with element type corresponding to T and the given // dimensions template - static Shape MakeShapeWithType( - tensorflow::gtl::ArraySlice dimensions) { + static Shape MakeShapeWithType(absl::Span dimensions) { return ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dimensions); } // Constructs a new shape with the given minor_to_major order in its Layout. // Returns a value shape such that shape.has_layout(). - static Shape MakeShapeWithLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice minor_to_major); + static Shape MakeShapeWithLayout(PrimitiveType element_type, + absl::Span dimensions, + absl::Span minor_to_major); - static Shape MakeShapeWithSparseLayout( - PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, - int64 max_sparse_elements); + static Shape MakeShapeWithSparseLayout(PrimitiveType element_type, + absl::Span dimensions, + int64 max_sparse_elements); // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). static Shape MakeShapeWithDescendingLayout( - PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions); + PrimitiveType element_type, absl::Span dimensions); // Returns a new Shape based on the given Shape with low-dimension-major // layout (i.e. 
{n, n-1, ..., 0}, like Fortran), and with the dimensions @@ -391,8 +389,7 @@ class ShapeUtil { // As MakeShape, but the object to write to is passed in. static void PopulateShape(PrimitiveType element_type, - tensorflow::gtl::ArraySlice dimensions, - Shape* shape); + absl::Span dimensions, Shape* shape); // Validates that the provided shape satisfies invariants. static Status ValidateShape(const Shape& shape); @@ -539,7 +536,7 @@ class ShapeUtil { // !HasLayout(shape) || // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape), // InversePermutation(permutation)). - static Shape PermuteDimensions(tensorflow::gtl::ArraySlice permutation, + static Shape PermuteDimensions(absl::Span permutation, const Shape& shape); // If we can go from `shape_pre` to `shape_post` by merely inserting or @@ -580,9 +577,9 @@ class ShapeUtil { // to its input and thus may be replaced with a bitcast. // // Precondition: Both input_shape and output_shape have explicit layouts. - static bool TransposeIsBitcast( - const Shape& input_shape, const Shape& output_shape, - tensorflow::gtl::ArraySlice dimension_mapping); + static bool TransposeIsBitcast(const Shape& input_shape, + const Shape& output_shape, + absl::Span dimension_mapping); // Returns whether a reshape from "input_shape" to "output_shape" is a // bitcast. @@ -597,8 +594,8 @@ class ShapeUtil { // layout). The layout of 'input_shape' is kept fixed. Returns // 'output_shape_with_layout' if such a layout can be found, and an error // otherwise. - static tensorflow::gtl::optional AlignLayouts( - const Shape& input_shape, const Shape& output_shape); + static absl::optional AlignLayouts(const Shape& input_shape, + const Shape& output_shape); // Returns a shape with the given dimension deleted. // For example: @@ -621,12 +618,12 @@ class ShapeUtil { // continue, or false otherwise. // // visitor_function must be a callable of type - // StatusOr(ArraySlice) or compatible. + // StatusOr(Span) or compatible. 
template static Status ForEachIndexWithStatus(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function) { return ForEachIndexInternal(shape, base, count, incr, visitor_function); } @@ -648,13 +645,12 @@ class ShapeUtil { } template - static void ForEachIndex(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + static void ForEachIndex(const Shape& shape, absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function) { ForEachIndexWithStatus(shape, base, count, incr, - [&](tensorflow::gtl::ArraySlice indices) { + [&](absl::Span indices) { return StatusOr(visitor_function(indices)); }) .IgnoreError(); @@ -676,7 +672,7 @@ class ShapeUtil { template static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { ForEachIndexWithStatus(shape, - [&](tensorflow::gtl::ArraySlice indices) { + [&](absl::Span indices) { return StatusOr(visitor_function(indices)); }) .IgnoreError(); @@ -687,18 +683,18 @@ class ShapeUtil { // matter. // // visitor_function must be a callable of type - // void(ArraySlice) or compatible. + // void(Span) or compatible. template static void ForEachIndexParallel(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function) { // The parallel version of ForEachIndexInternal can never fail. 
CHECK(ForEachIndexInternal( shape, base, count, incr, - [&visitor_function](tensorflow::gtl::ArraySlice indexes) - -> StatusOr { + [&visitor_function]( + absl::Span indexes) -> StatusOr { visitor_function(indexes); return true; }, @@ -720,9 +716,9 @@ class ShapeUtil { template static Status ForEachIndexInternal(const Shape& shape, - tensorflow::gtl::ArraySlice base, - tensorflow::gtl::ArraySlice count, - tensorflow::gtl::ArraySlice incr, + absl::Span base, + absl::Span count, + absl::Span incr, const FnType& visitor_function, bool parallel = false) { if (ShapeUtil::IsZeroElementArray(shape)) { @@ -737,13 +733,13 @@ class ShapeUtil { int64 n = -1; std::vector indexes(base.begin(), base.end()); const int kNumThreads = tensorflow::port::NumSchedulableCPUs(); - tensorflow::gtl::optional pool; + absl::optional pool; if (parallel) { pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads); } while (n < rank) { - if (pool != tensorflow::gtl::nullopt) { + if (pool != absl::nullopt) { pool->Schedule( [indexes, &visitor_function] { visitor_function(indexes); }); } else { diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index e5dd62ae9a3dd9b961a7ae03a99c19220dbd43e7..6ca4085aaf3bd1c181da3b94aa6c570e21172d0a 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" @@ -23,8 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { @@ -705,11 +705,10 @@ TEST(ShapeUtilTest, ForEachIndex) { Shape shape = ShapeUtil::MakeShape(F32, data.dimensions); // Increments at every invocation. int invocations = 0; - auto increment_func = - [&invocations](tensorflow::gtl::ArraySlice indexes) { - invocations++; - return true; - }; + auto increment_func = [&invocations](absl::Span indexes) { + invocations++; + return true; + }; std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); @@ -726,8 +725,7 @@ TEST(ShapeUtilTest, ForEachIndexWithStatus) { // Increments at every invocation. int invocations = 0; auto increment_func = - [&invocations]( - tensorflow::gtl::ArraySlice indexes) -> StatusOr { + [&invocations](absl::Span indexes) -> StatusOr { if (++invocations == 5) { return Unimplemented("Cannot increment beyond 5."); } @@ -748,7 +746,7 @@ TEST(ShapeUtilTest, ForEachIndexParallel) { Shape shape = ShapeUtil::MakeShape(F32, {10, 10}); int64 output[10][10]; int init = 5; - auto set_func = [&](tensorflow::gtl::ArraySlice indexes) { + auto set_func = [&](absl::Span indexes) { output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1]; }; @@ -849,13 +847,13 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) { std::iota(layout.begin(), layout.end(), 0); do { Shape s = ShapeUtil::MakeShapeWithLayout(F32, {10, 100, 1000}, layout); - SCOPED_TRACE(tensorflow::strings::StrCat("s=", ShapeUtil::HumanString(s))); + SCOPED_TRACE(absl::StrCat("s=", ShapeUtil::HumanString(s))); std::vector permutation(3); std::iota(permutation.begin(), permutation.end(), 0); do { - SCOPED_TRACE(tensorflow::strings::StrCat( - "permutation=", tensorflow::str_util::Join(permutation, ","))); + SCOPED_TRACE( + absl::StrCat("permutation=", 
absl::StrJoin(permutation, ","))); // TransposeIsBitcast takes the inverse of the permutation that // PermuteDimensions takes. diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc index 31844abd89a020c87c403353374a80fb639a3244..1c135dda864b3060b8bdc6369f18268d7c5c7f9e 100644 --- a/tensorflow/compiler/xla/sparse_index_array.cc +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -33,7 +33,7 @@ SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, } SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, - tensorflow::gtl::ArraySlice indices) + absl::Span indices) : SparseIndexArray(max_indices, rank, std::vector(indices.begin(), indices.end())) {} @@ -48,25 +48,24 @@ int64 SparseIndexArray::index_count() const { return indices_.size() / rank_; } -tensorflow::gtl::ArraySlice SparseIndexArray::At( +absl::Span SparseIndexArray::At( int64 sparse_element_number) const { CHECK_GT(rank_, 0); CHECK_GE(sparse_element_number, 0); CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); - return tensorflow::gtl::ArraySlice( + return absl::Span( indices_.data() + rank_ * sparse_element_number, rank_); } -tensorflow::gtl::MutableArraySlice SparseIndexArray::At( - int64 sparse_element_number) { +absl::Span SparseIndexArray::At(int64 sparse_element_number) { CHECK_GT(rank_, 0); CHECK_GE(sparse_element_number, 0); CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); - return tensorflow::gtl::MutableArraySlice( - indices_.data() + rank_ * sparse_element_number, rank_); + return absl::Span(indices_.data() + rank_ * sparse_element_number, + rank_); } -void SparseIndexArray::Append(tensorflow::gtl::ArraySlice index) { +void SparseIndexArray::Append(absl::Span index) { CHECK_GT(rank_, 0); CHECK_EQ(index.size(), rank_); indices_.insert(indices_.end(), index.begin(), index.end()); @@ -90,12 +89,12 @@ bool SparseIndexArray::Validate(const Shape& shape) const { if (num_indices < 2) { return 
true; } - tensorflow::gtl::ArraySlice last = At(0); + absl::Span last = At(0); if (!IndexUtil::IndexInBounds(shape, last)) { return false; } for (int64 n = 1; n < num_indices; ++n) { - tensorflow::gtl::ArraySlice next = At(n); + absl::Span next = At(n); if (!IndexUtil::IndexInBounds(shape, next)) { return false; } diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h index f2ce22d6721ff8da46f741ccedc2a63dea5994c8..a96d483462efd77ae4761541e8c79b2c84fa49f3 100644 --- a/tensorflow/compiler/xla/sparse_index_array.h +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -20,10 +20,11 @@ limitations under the License. #include +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { @@ -64,7 +65,7 @@ class SparseIndexArray { SparseIndexArray(int64 max_indices, int64 rank, std::vector indices = {}); SparseIndexArray(int64 max_indices, int64 rank, - tensorflow::gtl::ArraySlice indices); + absl::Span indices); // Returns the number of elements represented by the indices stored in the // array. @@ -72,12 +73,12 @@ class SparseIndexArray { // Returns a slice that refers to the given sparse index number. The argument // must be in the range [0, element_count()). - tensorflow::gtl::ArraySlice At(int64 sparse_element_number) const; - tensorflow::gtl::MutableArraySlice At(int64 sparse_element_number); + absl::Span At(int64 sparse_element_number) const; + absl::Span At(int64 sparse_element_number); // Adds the given index at the end of the array. The new size of the // SparseIndexArray must not exceed `max_indices`. - void Append(tensorflow::gtl::ArraySlice index); + void Append(absl::Span index); // Removes all indices from the array. 
void Clear(); @@ -95,8 +96,8 @@ class SparseIndexArray { int64 max_indices() const { return max_indices_; } // Returns a pointer to the int64 array that holds the sparse indices. - tensorflow::gtl::MutableArraySlice mutable_data() { return &indices_; } - tensorflow::gtl::ArraySlice data() const { return indices_; } + absl::Span mutable_data() { return absl::MakeSpan(indices_); } + absl::Span data() const { return indices_; } // Sorts this sparse index array along with the set of corresponding values. // The indices and values are sorted in the lexicographic order of the @@ -114,7 +115,7 @@ class SparseIndexArray { // std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl; // template - void SortWithValues(tensorflow::gtl::MutableArraySlice values); + void SortWithValues(absl::Span values); private: std::vector indices_; @@ -123,8 +124,7 @@ class SparseIndexArray { }; template -void SparseIndexArray::SortWithValues( - tensorflow::gtl::MutableArraySlice values) { +void SparseIndexArray::SortWithValues(absl::Span values) { int64 num_elements = index_count(); CHECK_EQ(values.size(), num_elements); std::vector sort_order; @@ -139,7 +139,7 @@ void SparseIndexArray::SortWithValues( // Reorder the array elements according to sort_order. Work through the array // and follow cycles so we can do the reorder in-place. - tensorflow::gtl::InlinedVector saved_index(rank()); + absl::InlinedVector saved_index(rank()); for (int64 i = 0; i < num_elements; ++i) { // sort_order[i] == -1 indicates the element has already been copied. 
if (sort_order[i] < 0) { diff --git a/tensorflow/compiler/xla/sparse_index_array_test.cc b/tensorflow/compiler/xla/sparse_index_array_test.cc index 7377f88958dcb7daf3d3f4f0e07966fdc9294580..e54057c4007078c76b79fe44d5706665e266c083 100644 --- a/tensorflow/compiler/xla/sparse_index_array_test.cc +++ b/tensorflow/compiler/xla/sparse_index_array_test.cc @@ -33,7 +33,7 @@ TEST(SparseIndexArrayTest, Sort) { std::vector values = { 12.0, 13.0, 11.0, 15.0, 14.0, 16.0, }; - a.SortWithValues(&values); + a.SortWithValues(absl::MakeSpan(values)); ASSERT_EQ(a.data(), std::vector({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7, 6, 7, 8})); ASSERT_EQ(values, std::vector({11.0, 12.0, 13.0, 14.0, 15.0, 16.0})); diff --git a/tensorflow/compiler/xla/status_macros.cc b/tensorflow/compiler/xla/status_macros.cc index a6b1f9004f096abb3b01d315938b0a23bea1ca48..b88fe367d7416a26c1147fd5e10fb20772814fe5 100644 --- a/tensorflow/compiler/xla/status_macros.cc +++ b/tensorflow/compiler/xla/status_macros.cc @@ -17,9 +17,8 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stacktrace.h" @@ -37,8 +36,7 @@ static void LogError(const Status& status, const char* filename, int line, if (TF_PREDICT_TRUE(log_severity != tensorflow::NUM_SEVERITIES)) { string stack_trace; if (should_log_stack_trace) { - stack_trace = - tensorflow::strings::StrCat("\n", tensorflow::CurrentStackTrace()); + stack_trace = absl::StrCat("\n", tensorflow::CurrentStackTrace()); } switch (log_severity) { case tensorflow::INFO: @@ -142,17 +140,15 @@ Status MakeErrorStream::Impl::GetStatus() { is_done_ = true; const string& stream_str = stream_.str(); - const string str = - prior_message_handling_ == kAppendToPriorMessage - ? 
tensorflow::strings::StrCat(prior_message_, stream_str) - : tensorflow::strings::StrCat(stream_str, prior_message_); + const string str = prior_message_handling_ == kAppendToPriorMessage + ? absl::StrCat(prior_message_, stream_str) + : absl::StrCat(stream_str, prior_message_); if (TF_PREDICT_FALSE(str.empty())) { - return MakeError(file_, line_, code_, - tensorflow::strings::StrCat( - str, "Error without message at ", file_, ":", line_), - true /* should_log */, - tensorflow::ERROR /* log_severity */, - should_log_stack_trace_); + return MakeError( + file_, line_, code_, + absl::StrCat(str, "Error without message at ", file_, ":", line_), + true /* should_log */, tensorflow::ERROR /* log_severity */, + should_log_stack_trace_); } else { return MakeError(file_, line_, code_, str, should_log_, log_severity_, should_log_stack_trace_); diff --git a/tensorflow/compiler/xla/test.h b/tensorflow/compiler/xla/test.h index 87a8c5f3a528289d47c1729ae6719aae47037c36..a657554dc2fd4fd1838639cac011bc0bb8b3d1eb 100644 --- a/tensorflow/compiler/xla/test.h +++ b/tensorflow/compiler/xla/test.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPLIER_XLA_TEST_H_ -#define TENSORFLOW_COMPLIER_XLA_TEST_H_ +#ifndef TENSORFLOW_COMPILER_XLA_TEST_H_ +#define TENSORFLOW_COMPILER_XLA_TEST_H_ // This header includes gmock.h and enables the use of gmock matchers in tests // in third_party/tensorflow/compiler/xla. @@ -45,4 +45,4 @@ limitations under the License. 
#include "tensorflow/core/platform/test.h" -#endif // TENSORFLOW_COMPLIER_XLA_TEST_H_ +#endif // TENSORFLOW_COMPILER_XLA_TEST_H_ diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index 8918350135fbb86973b228b35f5873fea8695b2f..3ede5e6e38a7a9e922fc0744f014c395dbd2324c 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index b8e700ae9705d596061628700036e223aca3a0f1..36b8fb26440f0f71207cc9b2af4d14f21e618cfe 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -43,6 +43,7 @@ cc_library( "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], alwayslink = True, ) @@ -68,7 +69,6 @@ cc_library( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", @@ -76,6 +76,8 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -98,6 +100,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings:str_format", + 
"@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -113,7 +118,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:backend", @@ -127,6 +131,10 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) @@ -144,6 +152,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -187,7 +196,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", @@ -201,6 +209,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -274,6 +285,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -385,6 +398,8 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -551,6 +566,8 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", 
"//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -567,8 +584,7 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -591,8 +607,8 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -614,12 +630,11 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -665,6 +680,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -683,7 +699,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -691,6 +706,7 @@ xla_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -715,10 +731,8 @@ 
xla_test( deps = [ ":client_library_test_base", ":hlo_test_base", - "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -742,7 +756,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -750,6 +763,7 @@ xla_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -813,6 +827,7 @@ CONVOLUTION_TEST_DEPS = [ "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -824,7 +839,10 @@ xla_test( timeout = "long", srcs = ["convolution_test.cc"], shard_count = 25, - deps = CONVOLUTION_TEST_DEPS, + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], ) xla_test( @@ -834,7 +852,10 @@ xla_test( backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, backends = ["gpu"], shard_count = 25, - deps = CONVOLUTION_TEST_DEPS, + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], ) xla_test( @@ -885,6 +906,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", 
"//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -918,6 +940,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -994,6 +1017,10 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -1067,6 +1094,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1102,7 +1130,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", @@ -1120,6 +1147,9 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -1148,6 +1178,9 @@ xla_test_library( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1155,6 +1188,7 @@ xla_test( name = "reduce_window_test", timeout = "long", srcs = [], + shard_count = 20, tags = [ "enable_for_xla_interpreter", "optonly", @@ -1210,6 +1244,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1220,12 +1255,12 @@ xla_test( "enable_for_xla_interpreter", ], 
deps = [ - ":client_library_test_base", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1236,12 +1271,12 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - ":client_library_test_base", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1285,6 +1320,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1350,6 +1386,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1400,7 +1437,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1411,6 +1447,9 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1425,11 +1464,11 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", 
"//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1443,14 +1482,12 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", @@ -1460,7 +1497,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1480,6 +1517,8 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1540,17 +1579,16 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1615,8 +1653,8 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", 
"//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1629,12 +1667,13 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1647,7 +1686,6 @@ xla_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", @@ -1658,6 +1696,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1751,6 +1790,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1772,6 +1812,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", "@llvm//:core", ], ) @@ -1792,6 +1833,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -1805,15 +1847,11 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", 
"//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1823,6 +1861,8 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", ], ) @@ -1830,18 +1870,12 @@ xla_test( name = "multioutput_fusion_test", srcs = ["multioutput_fusion_test.cc"], deps = [ - "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", - "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1849,6 +1883,9 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) @@ -1875,7 +1912,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", 
"//tensorflow/compiler/xla/tests:literal_test_util", @@ -1883,6 +1919,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:optional", ], ) @@ -1977,16 +2014,15 @@ xla_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/types:span", ], ) @@ -2009,6 +2045,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -2050,6 +2087,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", ], ) @@ -2076,6 +2114,8 @@ tf_cc_test( xla_test( name = "test_utils_test", srcs = ["test_utils_test.cc"], + # There is nothing backend specific in this test, so just pick an arbitrary backend. 
+ backends = ["cpu"], deps = [ ":local_client_test_base", ":test_utils", @@ -2084,6 +2124,7 @@ xla_test( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", "//tensorflow/core:test", ], ) @@ -2091,19 +2132,15 @@ xla_test( xla_test( name = "iota_test", srcs = ["iota_test.cc"], - blacklisted_backends = [ - "cpu", - "gpu", - ], + shard_count = 30, tags = [ "enable_for_xla_interpreter", + # Require optimized builds, iota_test_cpu is very slow in fastbuild. + "optonly", ], deps = [ ":client_library_test_base", - ":literal_test_util", ":xla_internal_test_main", - "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:lib", - "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 74f2e36f826cd82ce4015df857f3de67950beaeb..0bf4556b437fb1717a9c9773834fa3031cfbd6ea 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -40,6 +41,7 @@ limitations under the License. 
namespace xla { namespace { + class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: ErrorSpec error_spec_{0.0001, 0.0001}; @@ -293,6 +295,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { ComputeAndCompareR1(&b, expected, {lhs_data.get(), rhs_data.get()}); } +XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) { + XlaBuilder b(TestName()); + + std::vector lhs{static_cast(0x8000000000000000ULL)}; + std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + + std::vector rhs{static_cast(0x7FFFFFFFFFFFFFFFULL)}; + std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + + Lt(lhs_param, rhs_param); + + ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)}); +} + TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); XlaBuilder builder(TestName()); @@ -411,7 +429,65 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { ComputeAndCompareR1(&builder, {}, {}, error_spec_); } -XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { +class IntegerDivideOpTest : public ArrayElementwiseOpTest { + protected: + template + void TestDivRem(absl::Span dividends, absl::Span divisors, + absl::Span quotients, + absl::Span remainders) { + { + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; + auto dividend_data = + CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); + auto divisor_data = + CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); + Div(dividend, divisor); + + ComputeAndCompareR1(&builder, quotients, + {dividend_data.get(), divisor_data.get()}); + } + + // Test with a compile-time constant divisor. 
+ { + XlaBuilder builder(TestName()); + XlaOp dividend; + auto dividend_data = + CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); + Div(dividend, ConstantR1(&builder, divisors)); + + ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); + } + + { + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; + auto dividend_data = + CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); + auto divisor_data = + CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); + Rem(dividend, divisor); + + ComputeAndCompareR1(&builder, remainders, + {dividend_data.get(), divisor_data.get()}); + } + + // Test with a compile-time constant divisor. + { + XlaBuilder builder(TestName()); + XlaOp dividend; + auto dividend_data = + CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); + Rem(dividend, ConstantR1(&builder, divisors)); + + ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); + } + } +}; + +XLA_TEST_F(IntegerDivideOpTest, DivS32s) { // clang-format off // Some interesting values to test. std::vector vals = { @@ -435,58 +511,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) { } } - { - XlaBuilder builder(TestName()); - XlaOp dividend; - XlaOp divisor; - auto dividend_data = - CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - auto divisor_data = - CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - Div(dividend, divisor); - - ComputeAndCompareR1(&builder, quotients, - {dividend_data.get(), divisor_data.get()}); - } - - // Test with a compile-time constant divisor. 
- { - XlaBuilder builder(TestName()); - XlaOp dividend; - auto dividend_data = - CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - Div(dividend, ConstantR1(&builder, divisors)); - - ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); - } - - { - XlaBuilder builder(TestName()); - XlaOp dividend; - XlaOp divisor; - auto dividend_data = - CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - auto divisor_data = - CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - Rem(dividend, divisor); - - ComputeAndCompareR1(&builder, remainders, - {dividend_data.get(), divisor_data.get()}); - } + TestDivRem(dividends, divisors, quotients, remainders); +} - // Test with a compile-time constant divisor. - { - XlaBuilder builder(TestName()); - XlaOp dividend; - auto dividend_data = - CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); - Rem(dividend, ConstantR1(&builder, divisors)); +XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) { + std::vector dividends = {5, INT32_MIN}, divisors = {0, -1}, + quotients = {-1, INT32_MIN}, remainders = {5, 0}; - ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); - } + TestDivRem(dividends, divisors, quotients, remainders); } -XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { +XLA_TEST_F(IntegerDivideOpTest, DivU32s) { // clang-format off // Some interesting values to test. 
std::vector vals = { @@ -506,53 +541,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } } - { - XlaBuilder builder(TestName()); - XlaOp dividend; - XlaOp divisor; - auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", - &builder, ÷nd); - auto divisor_data = - CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - Div(dividend, divisor); - - ComputeAndCompareR1(&builder, quotients, - {dividend_data.get(), divisor_data.get()}); - } - - { - XlaBuilder builder(TestName()); - XlaOp dividend; - auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", - &builder, ÷nd); - Div(dividend, ConstantR1(&builder, divisors)); - - ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); - } - - { - XlaBuilder builder(TestName()); - XlaOp dividend; - XlaOp divisor; - auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", - &builder, ÷nd); - auto divisor_data = - CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); - Rem(dividend, divisor); - - ComputeAndCompareR1(&builder, remainders, - {dividend_data.get(), divisor_data.get()}); - } + TestDivRem(dividends, divisors, quotients, remainders); +} - { - XlaBuilder builder(TestName()); - XlaOp dividend; - auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", - &builder, ÷nd); - Rem(dividend, ConstantR1(&builder, divisors)); +XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) { + std::vector dividends = {5}, divisors = {0}, quotients = {-1}, + remainders = {5}; - ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); - } + TestDivRem(dividends, divisors, quotients, remainders); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 24b17b71007a1872462bed1f6b86ae1a5bb9922c..ac90a3adb6dbad30e3ef0b11438fb9a6fd6f8574 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ 
b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" @@ -41,7 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -382,7 +382,7 @@ struct BatchNormTestParam { friend ::std::ostream& operator<<(::std::ostream& os, const BatchNormTestParam& p) { - os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, "; + os << "bounds={" << absl::StrJoin(p.bounds, ", ") << "}, "; os << "feature_index=" << p.feature_index << ", "; os << "random_value_mean=" << p.random_value_mean << ", "; os << "random_value_var=" << p.random_value_var; diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index 6c20f654fe3df6a28e9633cd832c11b487894bad..65589b0d6af2ffca26776541eb05a093f43e0a9a 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -65,7 +65,7 @@ XLA_TEST_F(Bfloat16Test, LogOperation) { Log(x); ComputeAndCompareR0(&builder, static_cast(1.387f), {}, - error_spec_); + ErrorSpec(0.01, 0.01)); } XLA_TEST_F(Bfloat16Test, NegateScalarF16) { @@ -110,7 +110,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { {static_cast(5), static_cast(5)}) .get()}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02)); } XLA_TEST_F(Bfloat16Test, BatchNormGrad) { diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc 
b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 1d28e85b16596b0ec2717138fb2081878203e8b2..fe4267c73bd170f22a0456533f45e50be823a80b 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -53,10 +53,11 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { } } - std::unique_ptr MakeR3Data( - tensorflow::gtl::ArraySlice bounds, - tensorflow::gtl::ArraySlice minor_to_major, Shape* r3_shape, - Array3D* r3_array, float start, float end, int seed) { + std::unique_ptr MakeR3Data(absl::Span bounds, + absl::Span minor_to_major, + Shape* r3_shape, + Array3D* r3_array, float start, + float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout( @@ -66,10 +67,11 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { return r3_global_data; } - std::unique_ptr MakeR2Data( - tensorflow::gtl::ArraySlice bounds, - tensorflow::gtl::ArraySlice minor_to_major, Shape* r2_shape, - Array2D* r2_array, float start, float end, int seed) { + std::unique_ptr MakeR2Data(absl::Span bounds, + absl::Span minor_to_major, + Shape* r2_shape, + Array2D* r2_array, float start, + float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout( @@ -348,7 +350,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { Array3D expected_array(spec.output_bounds[0], spec.output_bounds[1], spec.output_bounds[2]); - auto Each = ([&](tensorflow::gtl::ArraySlice indices, float* value) { + auto Each = ([&](absl::Span indices, float* value) { float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0], indices[1] % spec.input_bounds[1], indices[2] % spec.input_bounds[2]); diff --git 
a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index c7b94b5bbaaa512ad36056f9e68a87cc706c24b1..74d4d2eb10c32b270a83aa04dd2e6025d7a56c26 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 59d917054be2ebe3a25f902f51972a682a5231b6..8a236db0ff2f63332892de822461dd1cc17276ca 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -17,18 +17,18 @@ limitations under the License. 
#include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -95,15 +95,14 @@ string ClientLibraryTestBase::TestName() const { } StatusOr> ClientLibraryTestBase::Execute( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return client_->Execute(computation, arguments, &execution_options_); } StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { @@ -115,7 +114,7 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( } StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, + XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout) { // Build the computation, as a convenience. 
TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); @@ -124,8 +123,7 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( StatusOr> ClientLibraryTestBase::ExecuteAndTransferReference( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { @@ -138,7 +136,7 @@ ClientLibraryTestBase::ExecuteAndTransferReference( } string ClientLibraryTestBase::ExecuteToString( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { auto computation_status = builder->Build(); if (!computation_status.ok()) { return computation_status.status().ToString(); @@ -156,7 +154,7 @@ string ClientLibraryTestBase::ExecuteToString( void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); @@ -164,15 +162,14 @@ void ClientLibraryTestBase::ComputeAndCompareR1( void ClientLibraryTestBase::ComputeAndCompareLiteral( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout) { + absl::Span arguments, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, shape_with_layout)); } void ClientLibraryTestBase::ComputeAndCompareLiteral( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + absl::Span arguments, ErrorSpec error, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, error, shape_with_layout)); @@ -180,7 +177,7 @@ void 
ClientLibraryTestBase::ComputeAndCompareLiteral( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output) { // Try with no layout requirement. @@ -196,8 +193,8 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( AsInt64Slice(expected.shape().dimensions()), minor_to_major); TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, &layout)); - verify_output(*actual, tensorflow::strings::StrCat( - "Test with output layout: ", + verify_output(*actual, + absl::StrCat("Test with output layout: ", ShapeUtil::HumanStringWithLayout(layout))); } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); return Status::OK(); @@ -205,7 +202,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& /*expected*/, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output, const Shape* output_with_layout) { @@ -252,13 +249,12 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( // Every argument has an assigned layout. 
TF_ASSIGN_OR_RETURN( auto actual, - ExecuteAndTransfer( - computation, - tensorflow::gtl::ArraySlice(arguments_with_layout), - output_with_layout)); + ExecuteAndTransfer(computation, + absl::Span(arguments_with_layout), + output_with_layout)); string error_message = "Test with input layouts: "; for (const auto& str : layout_strings) { - tensorflow::strings::StrAppend(&error_message, str, " "); + absl::StrAppend(&error_message, str, " "); } verify_output(*actual, error_message); return Status::OK(); @@ -269,7 +265,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments_passed_in, + absl::Span arguments_passed_in, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), arguments_passed_in.end()); @@ -290,10 +286,6 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( if (ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())) { LOG(WARNING) << "performing exact comparison of floating point numbers"; - } else { - TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) || - expected.shape().element_type() == PRED) - << ShapeUtil::HumanString(expected.shape()); } // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. 
@@ -333,8 +325,8 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments_passed_in, - ErrorSpec error, const Shape* shape_with_layout) { + absl::Span arguments_passed_in, ErrorSpec error, + const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), arguments_passed_in.end()); @@ -350,8 +342,6 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } - TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || - ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. @@ -391,8 +381,8 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } void ClientLibraryTestBase::ComputeAndCompareR1U8( - XlaBuilder* builder, tensorflow::StringPiece expected, - tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::string_view expected, + absl::Span arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); if (!actual_status.ok()) { @@ -411,7 +401,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( void ClientLibraryTestBase::ComputeAndCompareTuple( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); if (!actual_status.ok()) { @@ -423,7 +413,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( void ClientLibraryTestBase::ComputeAndCompareTuple( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { auto actual_status = ExecuteAndTransfer(builder, 
arguments); EXPECT_IS_OK(actual_status.status()); if (!actual_status.ok()) { @@ -434,7 +424,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( } void ClientLibraryTestBase::ComputeAndCompare( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { auto status_or_data = ComputeValueAndReference(builder, arguments); EXPECT_IS_OK(status_or_data); if (!status_or_data.ok()) { @@ -446,8 +436,7 @@ void ClientLibraryTestBase::ComputeAndCompare( } void ClientLibraryTestBase::ComputeAndCompare( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, - ErrorSpec error) { + XlaBuilder* builder, absl::Span arguments, ErrorSpec error) { auto status_or_data = ComputeValueAndReference(builder, arguments); EXPECT_IS_OK(status_or_data); if (!status_or_data.ok()) { @@ -460,7 +449,7 @@ void ClientLibraryTestBase::ComputeAndCompare( StatusOr, std::unique_ptr>> ClientLibraryTestBase::ComputeValueAndReference( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { // Transfer the arguments to the executor service. We put the unique_ptr's // into a vector to keep the data alive on the service until the end of this // function. 
@@ -546,7 +535,7 @@ XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { std::unique_ptr> ClientLibraryTestBase::CreatePatternedMatrix( int rows, int cols, float offset) { - auto array = MakeUnique>(rows, cols); + auto array = absl::make_unique>(rows, cols); for (int64 row = 0; row < rows; ++row) { for (int64 col = 0; col < cols; ++col) { (*array)(row, col) = col + (row * 1000.0f) + offset; @@ -561,7 +550,7 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, int cols_padded) { CHECK_GE(rows_padded, rows); CHECK_GE(cols_padded, cols); - auto array = MakeUnique>(rows_padded, cols_padded, 0.0); + auto array = absl::make_unique>(rows_padded, cols_padded, 0.0); for (int64 row = 0; row < rows; ++row) { for (int64 col = 0; col < cols; ++col) { (*array)(row, col) = col + (row * 1000.0f); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index b04a3b105ca017b6a91d271e603dcd0cc2068a33..22dfdfb0e4c67cc06fa748177c75cf35572196c8 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -21,6 +21,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -30,14 +33,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -49,8 +49,8 @@ namespace xla { // use_bfloat16_params with that value. Returns the result. template std::vector ExpandUseBfloat16( - tensorflow::gtl::ArraySlice use_bfloat16_params, - tensorflow::gtl::ArraySlice specs) { + absl::Span use_bfloat16_params, + absl::Span specs) { std::vector expanded; for (bool use_bfloat16 : use_bfloat16_params) { for (const auto& spec : specs) { @@ -93,15 +93,15 @@ class ClientLibraryTestBase : public ::testing::Test { // execution options. Modify execution_options_ in your test if you want to // customize the options. StatusOr> Execute( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); + XlaBuilder* builder, absl::Span arguments); StatusOr> ExecuteAndTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, + XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); StatusOr> ExecuteAndTransfer( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const Shape* shape_with_output_layout = nullptr); // This executes the computation via the reference client (which connects a @@ -109,13 +109,13 @@ class ClientLibraryTestBase : public ::testing::Test { // computation. 
StatusOr> ExecuteAndTransferReference( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const Shape* shape_with_output_layout = nullptr); // Run a computation and return its value as a string. If an error // occurs, then instead return the error as a string. string ExecuteToString(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); // Convenience methods for building and running a computation, transferring // the result, and comparing it to the expected value(s). Methods are @@ -125,102 +125,98 @@ class ClientLibraryTestBase : public ::testing::Test { // for integral types without the ErrorSpec parameter. template void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); template void ComputeAndCompareR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span expected, + absl::Span arguments); template void ComputeAndCompareR1(XlaBuilder* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span expected, + absl::Span arguments, ErrorSpec error); // As above, but uses a bitmap to hold the predicate vector to avoid // deficiencies of vector. 
void ComputeAndCompareR1(XlaBuilder* builder, const tensorflow::core::Bitmap& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR2(XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR2(XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); template void ComputeAndCompareR3(XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR3(XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); template void ComputeAndCompareR4(XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); template void ComputeAndCompareR4(XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, ErrorSpec error); // Build and run the computation and compare the result with the given // literal. shape_with_layout indicates the result layout to request when // calling Execute. - void ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout = nullptr); - void ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, - const Shape* shape_with_layout = nullptr); + void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected, + absl::Span arguments, + const Shape* shape_with_layout = nullptr); + void ComputeAndCompareLiteral(XlaBuilder* builder, const Literal& expected, + absl::Span arguments, + ErrorSpec error, + const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. 
Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const Shape* shape_with_layout = nullptr); Status ComputeAndCompareLiteralWithStatus( XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, + absl::Span arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // Compare the result of the computation to a strings. In XLA strings are // represented using rank-1 U8 shapes. - void ComputeAndCompareR1U8( - XlaBuilder* builder, tensorflow::StringPiece expected, - tensorflow::gtl::ArraySlice arguments); + void ComputeAndCompareR1U8(XlaBuilder* builder, absl::string_view expected, + absl::Span arguments); // Convenience method for running a built computation, transferring the // result, and comparing it to the expected tuple literal. - void ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments); - void ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error); + void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected, + absl::Span arguments); + void ComputeAndCompareTuple(XlaBuilder* builder, const Literal& expected, + absl::Span arguments, + ErrorSpec error); // Convenience method for running a built computation and comparing the result // with the reference result. void ComputeAndCompare(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); void ComputeAndCompare(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments, - ErrorSpec error); + absl::Span arguments, ErrorSpec error); // Create scalar operations for use in reductions. XlaComputation CreateScalarRelu(); @@ -337,7 +333,7 @@ class ClientLibraryTestBase : public ::testing::Test { // converted to bfloat16. 
template std::unique_ptr CreateR1Parameter( - tensorflow::gtl::ArraySlice values, int64 parameter_number, + absl::Span values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle); // Creates a parameter instruction that wraps the given constant array @@ -381,7 +377,7 @@ class ClientLibraryTestBase : public ::testing::Test { // actual). StatusOr, std::unique_ptr>> ComputeValueAndReference(XlaBuilder* builder, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); Client* client_; Client* ref_client_; // To compute reference result. @@ -390,12 +386,12 @@ class ClientLibraryTestBase : public ::testing::Test { private: Status ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output); Status ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const std::function& verify_output, const Shape* output_with_layout = nullptr); @@ -415,7 +411,7 @@ class ClientLibraryTestBase : public ::testing::Test { template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::unique_ptr expected_literal = LiteralUtil::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -425,7 +421,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || @@ -440,8 +436,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0( template void 
ClientLibraryTestBase::ComputeAndCompareR1( - XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span expected, + absl::Span arguments) { std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -450,8 +446,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1( template void ClientLibraryTestBase::ComputeAndCompareR1( - XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + XlaBuilder* builder, absl::Span expected, + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || @@ -467,7 +463,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::unique_ptr expected_literal = LiteralUtil::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -477,7 +473,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || @@ -493,7 +489,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::unique_ptr expected_literal = LiteralUtil::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -503,7 +499,7 @@ void 
ClientLibraryTestBase::ComputeAndCompareR3( template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || @@ -519,7 +515,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { std::unique_ptr expected_literal = LiteralUtil::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -529,7 +525,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + absl::Span arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || @@ -558,7 +554,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( - tensorflow::gtl::ArraySlice values, int64 parameter_number, + absl::Span values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = LiteralUtil::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { @@ -613,7 +609,7 @@ template std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( const int rows, const int cols, NativeT min_value, NativeT max_value, uint32 seed) { - auto result = MakeUnique>(rows, cols); + auto result = absl::make_unique>(rows, cols); PseudorandomGenerator generator(min_value, max_value, seed); for (int y = 0; y < rows; ++y) { for (int x = 0; x < cols; ++x) { diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc 
b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 7c52c9fbbb57f9291ea9f0966e2efa715819fb67..03d56964998f9abea21d6f82dee8faf86f9fe1d4 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -38,10 +38,9 @@ namespace { class CompilationCacheTest : public ClientLibraryTestBase { public: - void ExecuteComputationR0F32( - const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, float expected_result, - bool expect_cache_hit) { + void ExecuteComputationR0F32(const XlaComputation& computation, + absl::Span arguments, + float expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; std::unique_ptr result = client_ @@ -56,7 +55,7 @@ class CompilationCacheTest : public ClientLibraryTestBase { void ExecuteComputationR2F32( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, std::initializer_list> expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 5a06d061f0d83fff547502495ff8ab13fb421b70..8226b6de3f780197bc0f1145b617dba99803927f 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -17,6 +17,7 @@ limitations 
under the License. #include #include +#include "absl/strings/match.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -145,8 +145,8 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), - "depends on a parameter")) + EXPECT_TRUE( + absl::StrContains(value.status().ToString(), "depends on a parameter")) << value.status(); } } @@ -161,8 +161,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), - "depends on a parameter")) + EXPECT_TRUE( + absl::StrContains(value.status().ToString(), "depends on a parameter")) << value.status(); } } diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index b27c1044baf2c0002f166c53a81e4361c60d012a..25d10ab00af11b8ebb8147917e7cdbb21f9a42c4 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -642,5 +642,57 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { test_swap(11.24f, 5.55f); } +// Test conditional that duplicates tuple elements in the then and else +// computations. This is a regression test for b/112550242. 
+XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { + const Shape scalar = ShapeUtil::MakeShape(S32, {}); + const Shape tuple2 = ShapeUtil::MakeTupleShape({scalar, scalar}); + XlaComputation then_comp; + { + XlaBuilder builder(TestName() + ".then"); + auto p = Parameter(&builder, 0, tuple2, "then.p"); + auto e0 = GetTupleElement(p, 0); + auto e1 = GetTupleElement(p, 1); + Tuple(&builder, {e0, e1, e0}); + then_comp = builder.Build().ConsumeValueOrDie(); + } + XlaComputation else_comp; + { + XlaBuilder builder(TestName() + ".else"); + auto p = Parameter(&builder, 0, tuple2, "else.p"); + auto e0 = GetTupleElement(p, 0); + auto e1 = GetTupleElement(p, 1); + Tuple(&builder, {e0, e1, e1}); + else_comp = builder.Build().ConsumeValueOrDie(); + } + + { + // Pred is true case. + std::vector args; + args.push_back(std::move( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(123).get(), + LiteralUtil::CreateR0(-42).get()}))); + args.push_back(std::move(*LiteralUtil::CreateR0(true))); + XlaBuilder builder(TestName() + ".main"); + auto p = Parameter(&builder, 0, tuple2, "p0"); + auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); + Conditional(p_pred, p, then_comp, p, else_comp); + ComputeAndCompare(&builder, args); + } + { + // Pred is false case. 
+ std::vector args; + args.push_back(std::move( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(123).get(), + LiteralUtil::CreateR0(-42).get()}))); + args.push_back(std::move(*LiteralUtil::CreateR0(false))); + XlaBuilder builder(TestName() + ".main"); + auto p = Parameter(&builder, 0, tuple2, "p0"); + auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); + Conditional(p_pred, p, then_comp, p, else_comp); + ComputeAndCompare(&builder, args); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 1adc68cc4839dcd7d89741ec016f27bc9047c9a5..7a203d6873dbb5b69f96c50048c2c5ff3150c544 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -447,11 +448,11 @@ std::vector GetInterestingF16ConversionTestCases() { XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { std::vector test_cases = GetInterestingF16ConversionTestCases(); std::vector input; - c_transform(test_cases, std::back_inserter(input), - [](float f) { return Eigen::half(f); }); + absl::c_transform(test_cases, std::back_inserter(input), + [](float f) { return Eigen::half(f); }); std::vector expected_output; - c_transform(input, std::back_inserter(expected_output), - [](Eigen::half h) { return static_cast(h); }); + absl::c_transform(input, std::back_inserter(expected_output), + [](Eigen::half h) { return static_cast(h); }); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, @@ -470,8 +471,8 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { std::vector input = GetInterestingF16ConversionTestCases(); std::vector expected_output; - 
c_transform(input, std::back_inserter(expected_output), - [](float f) { return Eigen::half(f); }); + absl::c_transform(input, std::back_inserter(expected_output), + [](float f) { return Eigen::half(f); }); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 7b6bbc4f571af2e11306f95c24e243e78e0f4f4e..38b6da4fa96b0f6b7ed2d56852eb3ab2872f3520 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -88,9 +88,9 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) { XLA_TEST_F(ConvolutionDimensionNumbersTest, TwoConvsWithDifferentDimensionNumbers) { - auto input_array = MakeUnique>(2, 3, 5, 5); + auto input_array = absl::make_unique>(2, 3, 5, 5); input_array->FillWithMultiples(0.1); - auto weight_array = MakeUnique>(4, 3, 1, 1); + auto weight_array = absl::make_unique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = client_ diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 5ed8122e0073bde77bb2507a0ddd89c4365627c9..d2c6478b02423c93860244bc5eb91e652a3eac2e 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -18,6 +18,8 @@ limitations under the 
License. #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -26,16 +28,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -70,16 +70,16 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { const int kKernelSizeY = 2; const int kOutputActivationSizeZ = 256; const int kMiniBatchSize = 4; - auto alhs = - MakeUnique>(kMiniBatchSize, kInputActivationSizeZ, - kInputActivationSizeY, kInputActivationSizeX); + auto alhs = absl::make_unique>( + kMiniBatchSize, kInputActivationSizeZ, kInputActivationSizeY, + kInputActivationSizeX); alhs->FillWithMultiples(static_cast(1.0f)); ASSERT_EQ(3, alhs->width()); ASSERT_EQ(3, alhs->height()); - auto arhs = - MakeUnique>(kOutputActivationSizeZ, kInputActivationSizeZ, - kKernelSizeY, kKernelSizeX); + auto arhs = absl::make_unique>(kOutputActivationSizeZ, + kInputActivationSizeZ, + kKernelSizeY, kKernelSizeX); Array2D rhs_raster({ {1.0f, 0.0f}, // row 0 {0.0f, 0.0f}, // row 1 @@ -465,7 +465,7 @@ void iota_int_init_value(std::vector& values, 
int init_value) { } template -class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { +class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest { public: void RunTest() { XlaBuilder builder(TestName()); @@ -520,8 +520,139 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { } }; -TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x5_Valid, TestTypes); -TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x5_Valid, Types) { this->RunTest(); } +TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x3_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x3_Valid, Types) { this->RunTest(); } + +template +class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 3, 3, 5}; + std::vector filter_dims = {3, 3, 1, 15}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/5); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(16029), static_cast(16218), static_cast(16407), + static_cast(17172), static_cast(17370), static_cast(17568), + static_cast(18369), static_cast(18576), static_cast(18783), + static_cast(19620), static_cast(19836), static_cast(20052), + static_cast(20925), static_cast(21150), static_cast(21375)}); + auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, *expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid, Types) { + 
this->RunTest(); +} + +template +class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + std::vector input_dims = {1, 2, 2, 6}; + std::vector filter_dims = {2, 2, 2, 12}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums, + /*feature_group_count=*/3); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = LiteralUtil::CreateR1( + {static_cast(5076), static_cast(5160), static_cast(5244), + static_cast(5328), static_cast(6164), static_cast(6264), + static_cast(6364), static_cast(6464), static_cast(7380), + static_cast(7496), static_cast(7612), static_cast(7728)}); + auto expected_r4 = 
expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, *expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, Types) { + this->RunTest(); +} // Test fixture to run convolution tests with and without convolution // canonicalization enabled. @@ -765,5 +896,44 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { std::move(*LiteralUtil::CreateFromArray(filter_data))}); } +class ConvolutionHloTest : public HloTestBase {}; + +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64Forward)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %arg0 = f64[3,56,56,16] parameter(0) + %arg1 = f64[3,3,3,64] parameter(1) + ROOT %conv = f64[54,54,16,64] convolution(%arg0, %arg1), window={size=3x3}, dim_labels=f01b_i01o->01bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardFilter)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %arg0 = f64[2,5,8,1] parameter(0) + %arg1 = f64[2,5,8,2] parameter(1) + ROOT %conv = f64[4,4,1,2] convolution(%arg0, %arg1), window={size=5x8 pad=1_2x1_2}, dim_labels=f01b_i01o->01bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_CPU(ConvolveF64BackwardInput)) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %output = f64[4,5,16,16] parameter(0) + %kernel = f64[5,3,7,7] parameter(1) + %reverse = f64[5,3,7,7] reverse(f64[5,3,7,7] %kernel), dimensions={2,3} + ROOT %convolution = f64[4,3,16,16] convolution(%output, %reverse), window={size=7x7 pad=3_3x3_3}, 
dim_labels=bf01_io01->bf01 +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 5ef273e5a26ea8a16db864974c9bfa2c296cbce8..526626c1ddd902a4ba6c608f2b9355cece9ec833 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -54,7 +54,7 @@ class CopyOpTest : public HloTestBase { void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4, - tensorflow::gtl::ArraySlice permutation); + absl::Span permutation); }; XLA_TEST_F(CopyOpTest, CopyR0Bool) { @@ -187,9 +187,9 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { LiteralTestUtil::ExpectR3EqualArray3D(a, *result); } -void CopyOpTest::TestCopyConstantLayoutR4( - size_t n1, size_t n2, size_t n3, size_t n4, - tensorflow::gtl::ArraySlice permutation) { +void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, + size_t n4, + absl::Span permutation) { Array4D a(n1, n2, n3, n4); for (size_t i = 0; i < n1; ++i) { for (size_t j = 0; j < n2; ++j) { diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 13c777835eb2d2519d39205cdc96f0aac4850c7d..6f7fc0e6e52a69387a4c491871b6fcd97ac638b6 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ 
b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index 5f234f36a8543ad408fb3430b27844beb16a54b5..86fd1ceb1368feedb14088fa7045224440f6c4f9 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { @@ -36,7 +36,7 @@ class DeallocationTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. 
std::unique_ptr ExecuteAndCheckTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 2db6503afab748d7b778e26b2f9350ac64c7778b..eb15fc0593adf2d1bd84da4d0f708b6244f0fb33 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -42,7 +42,7 @@ class DeconstructTupleTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. 
std::unique_ptr ExecuteAndCheckTransfer( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, absl::Span arguments) { XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 0e9e92ed996fbb34826d19b670c7c4920a1aad13..5873516442fa63de47360acaa353abb3a97fe881 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -261,16 +262,14 @@ string PrintDotTestParam( const ::testing::TestParamInfo& test_param) { const DotTestParam& param = test_param.param; if (param.has_addend) { - return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, - "_MajorToMinor", - param.dot_lhs_row_major ? "T" : "F", - param.dot_rhs_row_major ? "T" : "F", - param.addend_row_major ? "T" : "F"); + return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? "T" : "F", + param.addend_row_major ? "T" : "F"); } else { - return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, - "_MajorToMinor", - param.dot_lhs_row_major ? "T" : "F", - param.dot_rhs_row_major ? "T" : "F"); + return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? 
"T" : "F"); } } diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 7f6f203a1ba48e0053f799c58bbbeae87aef1f7f..9bf3767ca3e229cd3eb37c1f51c526c7dd2bf0f8 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -114,14 +114,14 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void RunR1(tensorflow::gtl::ArraySlice input_values_int, + void RunR1(absl::Span input_values_int, const std::vector slice_starts, const std::vector& slice_sizes, - tensorflow::gtl::ArraySlice expected_values_int) { + absl::Span expected_values_int) { // bfloat16 has explicit constructors, so it does not implicitly convert the // way built-in types do, which is why we can't take the parameter as an - // ArraySlice. We also can't convert it to a vector, because - // vector is special so that it cannot be an ArraySlice, which + // Span. We also can't convert it to a vector, because + // vector is special so that it cannot be a Span, which // is what the code below wants. So instead we do this. 
Literal input_values = std::move(*LiteralUtil::CreateR1(input_values_int) @@ -385,10 +385,10 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template - void RunR1(tensorflow::gtl::ArraySlice input_values_int, - tensorflow::gtl::ArraySlice update_values_int, + void RunR1(absl::Span input_values_int, + absl::Span update_values_int, const std::vector slice_starts, - tensorflow::gtl::ArraySlice expected_values_int) { + absl::Span expected_values_int) { Literal input_values = std::move(*LiteralUtil::CreateR1(input_values_int) ->Convert(primitive_util::NativeToPrimitiveType()) diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 39cc6c5927f1d416e31f689487efc10c20371abe..3be9657db40a7ea073baca32d8a20ccd6fa8a274 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -16,13 +16,13 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -37,10 +37,9 @@ class FloorCeilTest : public ClientLibraryTestBase { }; // Runs a computation and comparison on expected vs f(input) - void TestR1F32(tensorflow::gtl::ArraySlice input, - tensorflow::gtl::ArraySlice expected, Function f) { - LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ") - << "}"; + void TestR1F32(absl::Span input, + absl::Span expected, Function f) { + LOG(INFO) << "input: {" << absl::StrJoin(expected, ", ") << 
"}"; XlaBuilder builder(TestName()); auto c = ConstantR1(&builder, input); if (f == kCeil) { diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 792be0d3fcd55621b9f8cdf0fdc28f7bb49294d1..7cb2f0cedfc2e74386bb3c01ca0b838e7cdcbce9 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -22,13 +22,14 @@ limitations under the License. #define EIGEN_USE_THREADS +#include "absl/memory/memory.h" +#include "absl/types/span.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -42,14 +43,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" -using tensorflow::gtl::ArraySlice; - namespace xla { namespace { @@ -113,7 +111,7 @@ class FusionTest : public HloTestBase { hlos[0] = builder.AddInstruction(std::move(root_hlo)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction( - ArraySlice(hlos, 0, Arity + 1), + absl::Span(hlos).subspan(0, Arity + 1), HloInstruction::FusionKind::kLoop); auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); @@ -127,12 +125,12 @@ class FusionTest : public HloTestBase { private: template - T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice xs); + T ComputeElementwiseAnswer(HloOpcode opcode, absl::Span xs); }; template <> float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - ArraySlice xs) { + absl::Span xs) { switch (opcode) { case HloOpcode::kAdd: return xs[0] + xs[1]; @@ -157,7 +155,7 @@ float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, template <> bool FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - ArraySlice xs) { + absl::Span xs) { switch (opcode) { case HloOpcode::kEq: return xs[0] == xs[1]; @@ -601,7 +599,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, - HloInstruction::FusionKind::kLoop); + HloInstruction::FusionKind::kInput); EXPECT_TRUE( LiteralTestUtil::Equal(*LiteralUtil::CreateR0(15), diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 
b77bece85ad1b2192b04330af9e60d3a424b59f4..6d634980449268e509d87ee064fbaaaf59abd195 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -25,17 +25,16 @@ limitations under the License. namespace xla { namespace { -using tensorflow::gtl::nullopt; +using absl::nullopt; class GatherOperationTest : public HloTestBase { protected: void RunTest(const string& hlo_text, Literal* operand, - Literal* gather_indices) { - RunTest(hlo_text, {operand, gather_indices}); + Literal* start_indices) { + RunTest(hlo_text, {operand, start_indices}); } - void RunTest(const string& hlo_text, - tensorflow::gtl::ArraySlice args) { + void RunTest(const string& hlo_text, absl::Span args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -52,18 +51,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[2,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 3} + slice_sizes={1, 3} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { @@ -74,18 +72,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} } 
)"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) { @@ -96,18 +93,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) { @@ -118,18 +115,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) { @@ -140,18 +137,18 @@ ENTRY main { operand = 
s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) ROOT gather = s32[2,1,1,2] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) { @@ -162,20 +159,20 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; std::unique_ptr operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) { @@ -186,20 +183,20 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} } )"; std::unique_ptr operand = 
LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, DynamicSlice) { @@ -210,18 +207,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) ROOT gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) { @@ -232,18 +228,18 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) ROOT gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, ZeroDimBounds) { @@ -254,17 +250,16 @@ ENTRY main { operand = s32[3,0] parameter(0) indices = s32[2] parameter(1) ROOT 
gather = s32[2,0] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 0} + slice_sizes={1, 0} } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { @@ -278,19 +273,19 @@ ENTRY main { operand = s32[3,3]{1,0} parameter(0) indices = s32[6,2]{1,0} parameter(1) gather = s32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR2( + std::unique_ptr start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) { @@ -304,19 +299,19 @@ ENTRY main { operand = s32[3,3]{1,0} parameter(0) indices = u32[6,2]{1,0} parameter(1) gather = s32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, 
{4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR2( + std::unique_ptr start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, NegativeIndex) { @@ -330,19 +325,19 @@ ENTRY main { operand = s32[3,3]{1,0} parameter(0) indices = s32[6,2]{1,0} parameter(1) gather = s32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = s32[6]{0} reshape(gather) } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR2( + std::unique_ptr start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) { @@ -356,19 +351,19 @@ ENTRY main { operand = u32[3,3]{1,0} parameter(0) indices = s32[6,2]{1,0} parameter(1) gather = u32[6,1,1]{2,1,0} gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1} + slice_sizes={1,1} ROOT result = u32[6]{0} reshape(gather) } )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR2( + std::unique_ptr start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), 
gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, OneScalarIndex) { @@ -379,17 +374,17 @@ ENTRY main { operand = s32[2,3,2]{2,1,0} parameter(0) index = s32[] parameter(1) ROOT gather = s32[1,3,2]{2,1,0} gather(operand, index), - output_window_dims={0,1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0}, + offset_dims={0,1,2}, + collapsed_slice_dims={}, + start_index_map={0}, index_vector_dim=0, - window_bounds={1,3,2} + slice_sizes={1,3,2} } )"; std::unique_ptr operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, ScalarResult) { @@ -400,16 +395,16 @@ ENTRY main { operand = s32[4]{0} parameter(0) index = s32[] parameter(1) ROOT gather = s32[] gather(operand, index), - output_window_dims={}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=0, - window_bounds={1} + slice_sizes={1} } )"; std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); - std::unique_ptr gather_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { @@ -420,17 +415,17 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[0] parameter(1) ROOT gather = s32[0,3] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0}, - gather_dims_to_operand_dims={0}, + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, index_vector_dim=1, - window_bounds={1, 3} + slice_sizes={1, 3} } 
)"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = LiteralUtil::CreateR1({}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) { @@ -441,11 +436,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[3,2] gather(operand, indices), - output_window_dims={0}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={0}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=1, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[3,2] broadcast(one), dimensions={} ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) @@ -453,9 +448,8 @@ ENTRY main { )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) { @@ -466,11 +460,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,3,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={1}, - gather_dims_to_operand_dims={1}, + offset_dims={1}, + collapsed_slice_dims={1}, + start_index_map={1}, index_vector_dim=2, - window_bounds={3, 1} + slice_sizes={3, 1} one = s32[] constant(1) one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) @@ -478,9 +472,9 @@ ENTRY main { )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + 
std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) { @@ -491,11 +485,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, - window_bounds={1, 1} + slice_sizes={1, 1} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -503,9 +497,9 @@ ENTRY main { )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) { @@ -516,11 +510,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -530,9 +524,9 @@ ENTRY main { LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, 
operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, @@ -544,11 +538,11 @@ ENTRY main { operand = s32[3,3,2] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,2] gather(operand, indices), - output_window_dims={1}, - elided_window_dims={0,1}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1}, + collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1,2} + slice_sizes={1,1,2} one = s32[] constant(1) one_broadcasted = s32[2,2] broadcast(one), dimensions={} ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) @@ -558,9 +552,9 @@ ENTRY main { LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) { @@ -571,11 +565,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2] parameter(1) gather = s32[1,1] gather(operand, indices), - output_window_dims={0,1}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={0,1}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[1,1] broadcast(one), dimensions={} ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) @@ -583,9 +577,8 @@ ENTRY main { )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = - LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, operand.get(), start_indices.get()); } XLA_TEST_F(GatherOperationTest, 
FusedBatchDynamicSlice) { @@ -596,11 +589,11 @@ ENTRY main { operand = s32[3,3] parameter(0) indices = s32[2,2] parameter(1) gather = s32[2,1,1] gather(operand, indices), - output_window_dims={1,2}, - elided_window_dims={}, - gather_dims_to_operand_dims={0,1}, + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0,1}, index_vector_dim=0, - window_bounds={1,1} + slice_sizes={1,1} one = s32[] constant(1) one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) @@ -608,9 +601,9 @@ ENTRY main { )"; std::unique_ptr operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr gather_indices = + std::unique_ptr start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), gather_indices.get()); + RunTest(hlo_text, operand.get(), start_indices.get()); } class GatherClientLibraryTest : public ClientLibraryTestBase {}; @@ -622,11 +615,11 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { // operand = s32[3,3] parameter(0) // indices = s32[2] parameter(1) // ROOT gather = s32[2,3] gather(operand, indices), - // output_window_dims={1}, - // elided_window_dims={0}, - // gather_dims_to_operand_dims={0}, + // offset_dims={1}, + // collapsed_slice_dims={0}, + // start_index_map={0}, // index_vector_dim=1, - // window_bounds={1, 3} + // slice_sizes={1, 3} // } XlaBuilder builder("gather_basic"); @@ -637,9 +630,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { auto operand = Parameter(&builder, 0, operand_shape, "operand"); auto indices = Parameter(&builder, 1, indices_shape, "indices"); GatherDimensionNumbers dim_numbers; - dim_numbers.add_output_window_dims(1); - dim_numbers.add_elided_window_dims(0); - dim_numbers.add_gather_dims_to_operand_dims(0); + dim_numbers.add_offset_dims(1); + dim_numbers.add_collapsed_slice_dims(0); + dim_numbers.add_start_index_map(0); dim_numbers.set_index_vector_dim(1); Gather(operand, 
indices, dim_numbers, {1, 3}); diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index 51450314b611b49c643fb6fd5b0c0d2e7205a2d2..1115e50fe3120b7dbd891f07dedcacefa5ecf3ea 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -126,9 +126,8 @@ INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, ::testing::Values(UnaryPredTestParam{ [](half x) { return isfinite(x); }, &IsFinite})); -using BinaryBuildFuncTy = - std::function)>; +using BinaryBuildFuncTy = std::function)>; struct BinaryOpTestParam { std::function compute_func; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index f05d1a8b9d372e720ae1634a9c8d5c0591e39b89..fc4c68246e62a4baa7a506ec37886102c35c4b3b 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -20,17 +20,20 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -39,9 +42,8 @@ namespace xla { namespace { -using 
tensorflow::StringPiece; -using tensorflow::gtl::ArraySlice; -using tensorflow::gtl::optional; +using absl::optional; +using absl::string_view; constexpr char kInterpreter[] = "interpreter"; @@ -83,24 +85,42 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace -HloTestBase::HloTestBase(bool allow_mixed_precision_in_hlo_verifier) +HloTestBase::HloTestBase(bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) : HloTestBase(GetTestPlatform(), GetReferencePlatform(), + verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier) {} HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, + bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier) : test_runner_(test_platform), reference_runner_(reference_platform) { - hlo_verifier_ = - MakeUnique(allow_mixed_precision_in_hlo_verifier); + hlo_verifier_ = absl::make_unique( + /*layout_sensitive=*/verifier_layout_sensitive, + /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier); } -/* static */ std::unique_ptr HloTestBase::CreateNewModule(const string& name) { - return MakeUnique(name, GetModuleConfigForTest()); + return absl::make_unique(name, GetModuleConfigForTest()); +} + +/* static */ +StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, + HloModule* module) { + const string module_str_before_run = module->ToProto().ShortDebugString(); + const auto status_or = hlo_pass->Run(module); + if (status_or.status().ok()) { + const string module_str_after_run = module->ToProto().ShortDebugString(); + if (!status_or.ValueOrDie()) { + // Check that the proto remains same. 
+ EXPECT_EQ(module_str_after_run, module_str_before_run); + } + } + return status_or; } -/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { +DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); @@ -109,14 +129,12 @@ std::unique_ptr HloTestBase::CreateNewModule(const string& name) { } StatusOr> HloTestBase::Execute( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { + std::unique_ptr module, absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments); } std::unique_ptr HloTestBase::ExecuteNoHloPasses( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { + std::unique_ptr module, absl::Span arguments) { return test_runner_ .Execute(std::move(module), arguments, /*run_hlo_passes=*/false) @@ -124,8 +142,7 @@ std::unique_ptr HloTestBase::ExecuteNoHloPasses( } std::unique_ptr HloTestBase::ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { + std::unique_ptr module, absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } @@ -148,7 +165,8 @@ StatusOr> HloTestBase::MakeReferenceModule( } StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, + const absl::Span arguments, const optional& error, bool run_hlo_passes, const std::function& reference_preprocessor) { TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status()); @@ -167,7 +185,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompare( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, + const absl::Span arguments, const optional& error, const std::function& 
reference_preprocessor) { auto result = @@ -180,7 +199,8 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - std::unique_ptr module, const ArraySlice arguments, + std::unique_ptr module, + const absl::Span arguments, const optional& error, const std::function& reference_preprocessor) { auto result = @@ -199,7 +219,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; - c_transform( + absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const std::unique_ptr& literal) { return literal.get(); }); @@ -213,7 +233,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( const auto& fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; - c_transform( + absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const std::unique_ptr& literal) { return literal.get(); }); @@ -222,8 +242,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompare( - const StringPiece hlo_string, - const tensorflow::gtl::optional& error, + string_view hlo_string, const absl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); @@ -236,7 +255,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } -::testing::AssertionResult HloTestBase::Run(const StringPiece hlo_string) { +::testing::AssertionResult HloTestBase::Run(string_view hlo_string) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); if (!module_or_status.ok()) { @@ -248,7 +267,7 @@ StatusOr<::testing::AssertionResult> 
HloTestBase::RunAndCompareInternal( MakeFakeArguments(module_or_status.ValueOrDie().get()) .ConsumeValueOrDie(); std::vector fake_argument_ptrs; - c_transform( + absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const std::unique_ptr& literal) { return literal.get(); }); return test_runner_ @@ -260,7 +279,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompareFromFile( - const string& filename, const tensorflow::gtl::optional& error, + const string& filename, const absl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest()); @@ -273,8 +292,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - const StringPiece hlo_string, - const tensorflow::gtl::optional& error, + string_view hlo_string, const absl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); @@ -288,7 +306,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( - const string& filename, const tensorflow::gtl::optional& error, + const string& filename, const absl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest()); @@ -301,10 +319,10 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } HloComputation* HloTestBase::FindComputation(HloModule* module, - tensorflow::StringPiece name) { + absl::string_view name) { auto computations = module->computations(); - auto it = c_find_if(computations, - [&](HloComputation* c) { return c->name() == name; }); + auto it = 
absl::c_find_if( + computations, [&](HloComputation* c) { return c->name() == name; }); if (it == computations.end()) { return nullptr; } @@ -312,11 +330,11 @@ HloComputation* HloTestBase::FindComputation(HloModule* module, } HloInstruction* HloTestBase::FindInstruction(HloModule* module, - tensorflow::StringPiece name) { + absl::string_view name) { for (const HloComputation* c : module->computations()) { auto instructions = c->instructions(); - auto it = c_find_if(instructions, - [&](HloInstruction* i) { return i->name() == name; }); + auto it = absl::c_find_if( + instructions, [&](HloInstruction* i) { return i->name() == name; }); if (it != instructions.end()) { return *it; } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 4232eeceb10b37a209f247ffa70fb9a08be337e6..4c88257bb27f5504588bba3ee0b14ac53c971225 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -31,8 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" @@ -72,20 +72,27 @@ class HloTestBase : public ::testing::Test { // options from command-line flags. If you want a fresh HloModule object and // then add HloComputations to it, it's recommended to use this method in your // tests. 
- static std::unique_ptr CreateNewModule( - const string& name = TestName()); + std::unique_ptr CreateNewModule(const string& name = TestName()); + + // Runs the hlo_pass with the provided module and returns the result. This + // function also verifies that the module remains unchanged when hlo_pass + // returns false as the StatusOr value. + static StatusOr RunHloPass(HloPassInterface* hlo_pass, + HloModule* module); protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the // interpreter is the only supported backend, it will be both the test backend // and the reference backend. - HloTestBase(bool allow_mixed_precision_in_hlo_verifier = true); + HloTestBase(bool verifier_layout_sensitive = false, + bool allow_mixed_precision_in_hlo_verifier = true); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, + bool verifier_layout_sensitive = false, bool allow_mixed_precision_in_hlo_verifier = true); ~HloTestBase() override {} @@ -93,10 +100,13 @@ class HloTestBase : public ::testing::Test { // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in // DebugOptions, e.g. when creating a module from a string or a file. - static DebugOptions GetDebugOptionsForTest(); + // + // This function is virtual so tests can specify an alternative set of debug + // options (e.g. disabling additional passes). + virtual DebugOptions GetDebugOptionsForTest(); // Gets an HloModuleConfig with options appropriate for tests. 
- static HloModuleConfig GetModuleConfigForTest() { + HloModuleConfig GetModuleConfigForTest() { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); return config; @@ -104,18 +114,15 @@ class HloTestBase : public ::testing::Test { // Executes the given module and return the result as a Literal. StatusOr> Execute( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments); + std::unique_ptr module, absl::Span arguments); // Same as above, except the module will be executed without running any HLO // passes on it. std::unique_ptr ExecuteNoHloPasses( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments); + std::unique_ptr module, absl::Span arguments); std::unique_ptr ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments); + std::unique_ptr module, absl::Span arguments); // Executes the given hlo module on two backends and compares results. // @@ -130,8 +137,8 @@ class HloTestBase : public ::testing::Test { // modified. ::testing::AssertionResult RunAndCompare( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - const tensorflow::gtl::optional& error, + const absl::Span arguments, + const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -139,23 +146,21 @@ class HloTestBase : public ::testing::Test { // optimization. ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - const tensorflow::gtl::optional& error, + const absl::Span arguments, + const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; // Executes an hlo module with fake inputs and compares the results. 
::testing::AssertionResult RunAndCompare( - std::unique_ptr module, - const tensorflow::gtl::optional& error, + std::unique_ptr module, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; // Same as above, except that the module will be executed without Hlo // optimization. ::testing::AssertionResult RunAndCompareNoHloPasses( - std::unique_ptr module, - const tensorflow::gtl::optional& error, + std::unique_ptr module, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -163,23 +168,23 @@ class HloTestBase : public ::testing::Test { // input. Module can be passed in directly, or parsed from an hlo_string, // or loaded from a file. ::testing::AssertionResult RunAndCompare( - const tensorflow::StringPiece hlo_string, - const tensorflow::gtl::optional& error, + const absl::string_view hlo_string, + const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; - ::testing::AssertionResult Run(const tensorflow::StringPiece hlo_string) + ::testing::AssertionResult Run(const absl::string_view hlo_string) TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareFromFile( - const string& filename, const tensorflow::gtl::optional& error, + const string& filename, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareNoHloPasses( - const tensorflow::StringPiece hlo_string, - const tensorflow::gtl::optional& error, + const absl::string_view hlo_string, + const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareNoHloPassesFromFile( - const string& filename, const tensorflow::gtl::optional& error, + const string& filename, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -222,10 +227,8 @@ class 
HloTestBase : public ::testing::Test { // // This is useful for tests which create HLOs from a string and then want to // inspect a particular computation or instruction. - HloComputation* FindComputation(HloModule* module, - tensorflow::StringPiece name); - HloInstruction* FindInstruction(HloModule* module, - tensorflow::StringPiece name); + HloComputation* FindComputation(HloModule* module, absl::string_view name); + HloInstruction* FindInstruction(HloModule* module, absl::string_view name); // Return an HLO verifier constructed for the test backend. HloVerifier& verifier() const { return *hlo_verifier_; } @@ -255,8 +258,8 @@ class HloTestBase : public ::testing::Test { // error happens before the results are computed, returns the error status. StatusOr<::testing::AssertionResult> RunAndCompareInternal( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - const tensorflow::gtl::optional& error, bool run_hlo_passes, + const absl::Span arguments, + const absl::optional& error, bool run_hlo_passes, const std::function& reference_preprocessor); }; diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index ad1f5b9eed8b5b140100c1fa35dc7d698e3db48b..8f86c528d0f346b0264948d592660911880f96d1 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -24,8 +25,11 @@ limitations under the License. 
namespace xla { -HloVerifiedTestBase::HloVerifiedTestBase() - : shape_verifier_(MakeUnique()) {} +HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, + bool allow_mixed_precision) + : HloTestBase( + /*verifier_layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {} HloVerifiedTestBase::~HloVerifiedTestBase() { // We can't call the ASSERT or EXPECT test macros in destructors, so we @@ -50,8 +54,7 @@ void HloVerifiedTestBase::TearDown() { } void HloVerifiedTestBase::VerifyModule(HloModule* module) { - HloVerifier verifier(/*allow_mixed_precision=*/true); - xla::StatusOr mutated = verifier.Run(module); + xla::StatusOr mutated = verifier().Run(module); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); } else { @@ -72,7 +75,7 @@ HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { return modules_.back().get(); } -void HloVerifiedTestBase::ParseAndVerifyModule(tensorflow::StringPiece hlo_text, +void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config) { CHECK(!module_) << "Called ParseModule when test already has a module."; TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config)); diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index 5b28c01c369fa1ae1c7941f5c8139882c4dbed08..8fbc4fa753ebf0c02b44ce10edf9251d28113f98 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -29,7 +29,8 @@ namespace xla { // performs verification on that module on tear-down. class HloVerifiedTestBase : public HloTestBase { protected: - HloVerifiedTestBase(); + explicit HloVerifiedTestBase(bool layout_sensitive = false, + bool allow_mixed_precision = false); ~HloVerifiedTestBase() override; // Constructs a default shape verifier. 
@@ -44,32 +45,28 @@ class HloVerifiedTestBase : public HloTestBase { // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). HloModule& module(); - void ParseAndVerifyModule(tensorflow::StringPiece hlo_text, + void ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config = HloModuleConfig()); - // Sets the shape-size function used during hlo verification. If this isn't - // called, a default ShapeVerifier is used instead. - void SetShapeVerifier(std::unique_ptr shape_verifier) { - shape_verifier_ = std::move(shape_verifier); - } - // Creates a new module for a test, and stores it in modules_ so it can be // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent // creation of unverified modules. HloModule* CreateNewModule(const string& name = TestName()); + private: + void VerifyModule(HloModule* module); + // It is confusing to store modules created by module() and CreateNewModule() // in different fields, but it allows us to migrate tests to // HloVerifiedTestBase more easily, so it's a win because we can verify more // modules. See b/80488902. - private: + // // Lazily populated. Access via module(). std::unique_ptr module_; // Populated by calls to CreateNewModule. std::vector> modules_; - std::unique_ptr shape_verifier_; + bool tear_down_called_ = false; - static void VerifyModule(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 17ac95ae0198d98490b25f7f2edd32d1e0495803..310f3495922250d68aa463fcbb24ef0b04603d09 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -23,40 +23,95 @@ limitations under the License. 
namespace xla { namespace { -class IotaTest : public ClientLibraryTestBase { - public: - explicit IotaTest(se::Platform* platform = nullptr) - : ClientLibraryTestBase(platform) {} - template - std::vector GetExpected(const int64 num_elements) { - std::vector result(num_elements); - std::iota(result.begin(), result.end(), 0); - return result; +template +std::vector GetR1Expected(const int64 num_elements) { + std::vector result(num_elements); + std::iota(result.begin(), result.end(), 0); + return result; +} + +class IotaR1Test + : public ClientLibraryTestBase, + public ::testing::WithParamInterface> {}; + +TEST_P(IotaR1Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + Iota(&builder, element_type, num_elements); + if (element_type == F32) { + ComputeAndCompareR1(&builder, GetR1Expected(num_elements), {}, + ErrorSpec{0.0001}); + } else if (element_type == U32) { + ComputeAndCompareR1(&builder, GetR1Expected(num_elements), + {}); + } else { + CHECK_EQ(element_type, S32); + ComputeAndCompareR1(&builder, GetR1Expected(num_elements), + {}); } -}; - -XLA_TEST_F(IotaTest, SimpleR1) { - for (int num_elements = 1; num_elements < 10000001; num_elements *= 10) { - { - XlaBuilder builder(TestName() + "_f32"); - IotaGen(&builder, F32, num_elements); - ComputeAndCompareR1(&builder, GetExpected(num_elements), {}, - ErrorSpec{0.0001}); - } - { - XlaBuilder builder(TestName() + "_u32"); - IotaGen(&builder, U32, num_elements); - ComputeAndCompareR1(&builder, GetExpected(num_elements), - {}); - } - { - XlaBuilder builder(TestName() + "_s32"); - IotaGen(&builder, S32, num_elements); - ComputeAndCompareR1(&builder, GetExpected(num_elements), - {}); - } +} + +INSTANTIATE_TEST_CASE_P(IotaR1TestInstantiation, IotaR1Test, + ::testing::Combine(::testing::Values(F32, U32, S32), + ::testing::Range(/*start=*/10, + /*end=*/10001, + 
/*step=*/10))); + +class IotaR2Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(IotaR2Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + const int64 iota_dim = std::get<2>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector dimensions = {42}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); } } +INSTANTIATE_TEST_CASE_P(IotaR2TestInstantiation, IotaR2Test, + ::testing::Combine(::testing::Values(F32, S32), + ::testing::Range(/*start=*/10, + /*end=*/1001, + /*step=*/10), + ::testing::Values(0, 1))); + +class IotaR3Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(IotaR3Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + const int64 iota_dim = std::get<2>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector dimensions = {42, 19}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); + } +} + +INSTANTIATE_TEST_CASE_P(IotaR3TestInstantiation, IotaR3Test, + ::testing::Combine(::testing::Values(F32, S32), + ::testing::Range(/*start=*/10, + /*end=*/1001, + /*step=*/10), + ::testing::Values(0, 1, 2))); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc 
b/tensorflow/compiler/xla/tests/literal_test_util.cc index cde1dcd9cd10c86107f495a92be42b57bf6a085b..554eb24d44168caa7d7252015e3d99f2d567df9b 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -35,8 +35,7 @@ void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { int64 now_usec = tensorflow::Env::Default()->NowMicros(); string filename = tensorflow::io::JoinPath( tensorflow::testing::TmpDir(), - tensorflow::strings::Printf("tempfile-%s-%llx-%s", get_hostname().c_str(), - now_usec, name.c_str())); + absl::StrFormat("tempfile-%s-%x-%s", get_hostname(), now_usec, name)); TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, literal.ToProto())); LOG(ERROR) << "wrote to " << name << " file: " << filename; @@ -94,7 +93,7 @@ void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, /* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( const LiteralSlice& expected, const LiteralSlice& actual, - const tensorflow::gtl::optional& error) { + const absl::optional& error) { if (error.has_value()) { VLOG(1) << "Expects near"; return StatusToAssertion(literal_comparison::Near( diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 31a099c15f1f20457c90de97054f68a31eb49011..96f72212f35f5e6e98e2dc24fd9a87891a326e8f 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -21,6 +21,8 @@ limitations under the License. 
#include #include +#include "absl/types/optional.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -32,8 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -62,7 +62,7 @@ class LiteralTestUtil { static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual); template - static void ExpectR1Equal(tensorflow::gtl::ArraySlice expected, + static void ExpectR1Equal(absl::Span expected, const LiteralSlice& actual); template static void ExpectR2Equal( @@ -102,7 +102,7 @@ class LiteralTestUtil { const ErrorSpec& error); template - static void ExpectR1Near(tensorflow::gtl::ArraySlice expected, + static void ExpectR1Near(absl::Span expected, const LiteralSlice& actual, const ErrorSpec& error); template @@ -146,7 +146,7 @@ class LiteralTestUtil { // will be compared recursively. 
static ::testing::AssertionResult NearOrEqual( const LiteralSlice& expected, const LiteralSlice& actual, - const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; + const absl::optional& error) TF_MUST_USE_RESULT; private: TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); @@ -160,7 +160,7 @@ template template /* static */ void LiteralTestUtil::ExpectR1Equal( - tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual) { + absl::Span expected, const LiteralSlice& actual) { EXPECT_TRUE(Equal(*LiteralUtil::CreateR1(expected), actual)); } @@ -206,7 +206,7 @@ template template /* static */ void LiteralTestUtil::ExpectR1Near( - tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual, + absl::Span expected, const LiteralSlice& actual, const ErrorSpec& error) { EXPECT_TRUE(Near(*LiteralUtil::CreateR1(expected), actual, error)); } diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index f297b2b847f570d26e71ddcd8e34bc626f982e1f..4151bfae0332ffc706ba730d181c487eabab856f 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -20,9 +20,9 @@ limitations under the License. 
#include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -80,7 +80,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { std::vector results; TF_CHECK_OK(env->GetMatchingPaths(pattern, &results)); - LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]"; + LOG(INFO) << "results: [" << absl::StrJoin(results, ", ") << "]"; EXPECT_EQ(3, results.size()); for (const string& result : results) { LiteralProto literal_proto; @@ -105,8 +105,10 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto actual = LiteralUtil::CreateR1({4, 5, 6}); ::testing::AssertionResult result = LiteralTestUtil::Equal(*expected, *actual); - EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}")); - EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}")); + EXPECT_THAT(result.message(), + ::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); + EXPECT_THAT(result.message(), + ::testing::HasSubstr("Actual literal:\n{4, 5, 6}")); } TEST(LiteralTestUtilTest, NearComparatorR1) { diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index e719da54d45d3e6eb3f3e14d3fa3076db2081e04..8d658695576035cdc34a213847460dd80de5f67e 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" @@ -125,7 +126,7 @@ class LLVMCompilerTest : public ::testing::Test { static std::unique_ptr CreateNewModule() { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return MakeUnique(TestName(), config); + return absl::make_unique(TestName(), config); } }; diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 6fc11150978931f980349799372872f9fb68f292..0487d314094edcab61a92de32f14113dd19673fa 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -51,8 +51,9 @@ void LlvmIrGenTestBase::CompileAndVerifyIr( std::unique_ptr hlo_module, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); - TF_ASSERT_OK(CompileToExecutable(std::move(hlo_module)).status()); + Status status = CompileToExecutable(std::move(hlo_module)).status(); ResetIrHook(); + TF_ASSERT_OK(status); StatusOr filecheck_result = RunFileCheck(ir_, pattern); TF_ASSERT_OK(filecheck_result.status()); @@ -73,9 +74,10 @@ void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr( std::unique_ptr hlo_module, const AotCompilationOptions& options, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); - TF_ASSERT_OK( - CompileToAotCompilationResult(std::move(hlo_module), options).status()); + Status status = + CompileToAotCompilationResult(std::move(hlo_module), options).status(); ResetIrHook(); + TF_ASSERT_OK(status); StatusOr filecheck_result = RunFileCheck(ir_, pattern); ASSERT_TRUE(filecheck_result.ok()); diff --git 
a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index e2cd5bcc5a95f692dcf4a43d717252bfe876aa81..237a4a361e386e24c2897c42602eb60ca7234731 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -53,7 +53,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { // deallocation happen on the right allocator. ExecutableRunOptions options; options.set_allocator(allocator); - tensorflow::gtl::optional result = + absl::optional result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), options); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index eaddf756dbc913dd9668cd22228fbd18c2c33309..a8c68fc7fdbad30068af44606f559ca96603fe66 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -18,11 +18,11 @@ limitations under the License. 
#include +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -156,7 +156,7 @@ ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const { ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()) .ConsumeValueOrDie(); @@ -164,7 +164,7 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options) { return ExecuteLocally(computation, arguments, build_options, run_options) @@ -173,14 +173,14 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( StatusOr LocalClientTestBase::ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments) { + absl::Span arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()); } StatusOr LocalClientTestBase::ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options) { std::vector argument_layouts(arguments.size()); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h 
b/tensorflow/compiler/xla/tests/local_client_test_base.h index b4477e9a6b23363ee3a1380f9f98f4b8226f6920..90095c5d410f1561a1303a0f62f44d22ed5340f9 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -93,19 +93,19 @@ class LocalClientTestBase : public ::testing::Test { // options. 
StatusOr ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); StatusOr ExecuteLocally( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); ScopedShapedBuffer ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments); + absl::Span arguments); ScopedShapedBuffer ExecuteLocallyOrDie( const XlaComputation& computation, - tensorflow::gtl::ArraySlice arguments, + absl::Span arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index da8c42d465340f2af3d6acd2c3676b69512f193f..edb592f43ec778a3fe6e5ef936827dd612791760 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -17,12 +17,14 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -32,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -133,10 +134,9 @@ class TestLinspaceMaxParametric float from = -128.0, to = 256.0; std::unique_ptr> alhs = MakeLinspaceArray2D(from, to, rows, cols); - auto arhs = MakeUnique>(rows, cols, static_cast(1.0f)); + auto arhs = absl::make_unique>(rows, cols, static_cast(1.0f)); - XlaBuilder builder( - tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); + XlaBuilder builder(absl::StrFormat("max_%dx%d_linspace", rows, cols)); auto lhs = ConstantR2FromArray2D(&builder, *alhs); auto rhs = ConstantR2FromArray2D(&builder, *arhs); Max(lhs, rhs); @@ -158,7 +158,7 @@ class TestLinspaceMaxParametric string PrintTestLinspaceMaxParam( const ::testing::TestParamInfo& test_param) { const TestLinspaceMaxParam& param = test_param.param; - return tensorflow::strings::StrCat(param.rows, "r", param.cols, "c"); + return absl::StrCat(param.rows, "r", param.cols, "c"); } #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index eb06b115daa96bccd73de30bb7fa30733a6fd947..05f90ba9fb7d781f64bd52008423f603397ce628 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -19,10 +19,12 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -36,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -46,18 +47,27 @@ limitations under the License. namespace xla { namespace { -using ::tensorflow::gtl::ArraySlice; class MultiOutputFusionTest : public HloTestBase { protected: MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; } + // Layout assignment assumes that there are no fusions in the input graph. + // Since the purpose of this test is to send pre-fused graphs to XLA, we have + // to do layout assignment ourselves. 
+ DebugOptions GetDebugOptionsForTest() override { + auto opts = HloTestBase::GetDebugOptionsForTest(); + opts.add_xla_disable_hlo_passes("layout-assignment"); + return opts; + } + void RunTest2D(bool manual_fusion, int64 size) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - const Shape elem_shape0 = ShapeUtil::MakeShape(F32, {}); - const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size}); + const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); + const Shape elem_shape2 = + ShapeUtil::MakeShapeWithLayout(F32, {size, size}, {1, 0}); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(8.0f))); @@ -85,8 +95,8 @@ class MultiOutputFusionTest : public HloTestBase { auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { - auto tuple = computation->AddInstruction(HloInstruction::CreateTuple( - ArraySlice({sub, add2}, 0, 2))); + auto tuple = + computation->AddInstruction(HloInstruction::CreateTuple({sub, add2})); auto gte0 = computation->AddInstruction( HloInstruction::CreateGetTupleElement(elem_shape2, tuple, 0)); auto gte1 = computation->AddInstruction( @@ -100,10 +110,10 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal arg1(ShapeUtil::MakeShape(F32, {size, size})); + Literal arg1(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); arg1.PopulateWithValue(2.5f); - Literal expect(ShapeUtil::MakeShape(F32, {size, size})); + Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); expect.PopulateWithValue(size * 1.5f * 3.5f); auto actual = ExecuteAndTransfer(std::move(hlo_module), @@ -115,8 +125,10 @@ class MultiOutputFusionTest : public HloTestBase { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - const Shape elem_shape_F32 = ShapeUtil::MakeShape(F32, {size}); - const Shape elem_shape_U8 = ShapeUtil::MakeShape(F64, {size}); + const 
Shape elem_shape_F32 = + ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}); + const Shape elem_shape_U8 = + ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, elem_shape_F32, "0")); auto param1 = builder.AddInstruction( @@ -136,17 +148,18 @@ class MultiOutputFusionTest : public HloTestBase { HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {size, 1}), add)); + ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, 1}), add)); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( - ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums)); + ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape, + dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { - auto tuple = computation->AddInstruction(HloInstruction::CreateTuple( - ArraySlice({sub_U8, add}, 0, 2))); + auto tuple = computation->AddInstruction( + HloInstruction::CreateTuple({sub_U8, add})); auto gte0 = computation->AddInstruction( HloInstruction::CreateGetTupleElement(elem_shape_U8, tuple, 0)); @@ -161,9 +174,9 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal input0(ShapeUtil::MakeShape(F32, {size})); + Literal input0(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size})); input0.PopulateWithValue(2.5f); - Literal input1(ShapeUtil::MakeShape(F64, {size})); + Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size})); input1.PopulateWithValue(1.); Literal expect = @@ -291,7 +304,7 @@ const char* const kScalarOps = R"( XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 
= f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -323,7 +336,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -355,7 +368,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -388,7 +401,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -422,7 +435,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -457,7 +470,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -494,7 +507,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + 
const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) init1 = f32[] parameter(1) @@ -529,7 +542,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) { p0 = f16[2,2,2]{2,1,0} parameter(0) convert = f32[2,2,2]{2,1,0} convert(p0) diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index ca21b0b2ba590a6daadf2c8d3d9ad213514b0f0f..cbeddffacfa4a0fc560e8b9f9a8d7bd23ff32e55 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -16,12 +16,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -140,7 +140,7 @@ XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { XlaBuilder b(TestName()); - auto input = MakeUnique>(1, 1, 3, 2); + auto input = absl::make_unique>(1, 1, 3, 2); Array2D input_xy({ {1.0f, 2.0f}, // row 0 {3.0f, 4.0f}, // row 1 @@ -151,7 +151,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); - auto expected = MakeUnique>(2, 3, 3, 2); + auto expected = absl::make_unique>(2, 3, 3, 2); 
expected->Fill(1.5); (*expected)(1, 0, 0, 0) = 1.0f; (*expected)(1, 0, 0, 1) = 2.0f; @@ -171,7 +171,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { AddParam(*LiteralUtil::CreateR0(pad_value), &b), r4_padding_on_dim0_dim1_); - auto expected = MakeUnique>(8, 5, 1, 1); + auto expected = absl::make_unique>(8, 5, 1, 1); expected->Fill(pad_value); (*expected)(1, 0, 0, 0) = 1.0f; (*expected)(1, 2, 0, 0) = 2.0f; @@ -269,7 +269,7 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { XLA_TEST_F(PadTest, Pad4DU8Array) { XlaBuilder b(TestName()); - auto input = MakeUnique>(1, 1, 3, 2); + auto input = absl::make_unique>(1, 1, 3, 2); Array2D input_xy({ {1, 2}, // row 0 {3, 4}, // row 1 @@ -280,7 +280,7 @@ XLA_TEST_F(PadTest, Pad4DU8Array) { Pad(AddParam(*input, &b), ConstantR0(&b, 35), r4_padding_on_dim0_dim1_); - auto expected = MakeUnique>(2, 3, 3, 2); + auto expected = absl::make_unique>(2, 3, 3, 2); expected->Fill(35); (*expected)(1, 0, 0, 0) = 1; (*expected)(1, 0, 0, 1) = 2; @@ -301,13 +301,13 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { Pad(input, ConstantR0(&b, false), r4_padding_on_dim0_dim1_); // For the same reason, use Select to convert boolean values to int32. 
- auto zeros = MakeUnique>(2, 3, 3, 2); - auto ones = MakeUnique>(2, 3, 3, 2); + auto zeros = absl::make_unique>(2, 3, 3, 2); + auto ones = absl::make_unique>(2, 3, 3, 2); zeros->Fill(0); ones->Fill(1); Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b)); - auto expected = MakeUnique>(2, 3, 3, 2); + auto expected = absl::make_unique>(2, 3, 3, 2); expected->Fill(0); (*expected)(1, 0, 0, 0) = 1; (*expected)(1, 0, 0, 1) = 1; @@ -321,7 +321,7 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { XLA_TEST_P(PadTestFloat, Large2DPad) { XlaBuilder b(TestName()); - auto ones = MakeUnique>(4, 4); + auto ones = absl::make_unique>(4, 4); ones->Fill(1.0f); auto input = AddParam(*ones, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -342,7 +342,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { constexpr int64 in_rows = 35; constexpr int64 in_cols = 35; - auto operand = MakeUnique>(in_rows, in_cols); + auto operand = absl::make_unique>(in_rows, in_cols); operand->FillUnique(0.0f); auto input = AddParam(*operand, &b); @@ -368,7 +368,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { constexpr int64 low_padding = 0; int64 high_padding[2] = {5, 7}; constexpr int64 interior_padding = 0; - auto operand = MakeUnique>(in_rows, in_cols); + auto operand = absl::make_unique>(in_rows, in_cols); operand->FillUnique(1.0f); auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -395,7 +395,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {-3, 4}; constexpr int64 interior_padding = 0; - auto operand = MakeUnique>(in_rows, in_cols); + auto operand = absl::make_unique>(in_rows, in_cols); operand->FillUnique(1.0f); auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -423,7 +423,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { int64 low_padding[2] = {4, -1}; int64 high_padding[2] = {-2, -4}; int64 interior_padding[2] = {1, 2}; - auto operand = 
MakeUnique>(in_rows, in_cols); + auto operand = absl::make_unique>(in_rows, in_cols); operand->FillUnique(1.0f); auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); @@ -446,7 +446,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { // Regression test for b/31827337. XLA_TEST_P(PadTestFloat, ReducePad) { XlaBuilder b(TestName()); - auto ones = MakeUnique>(2, 2, 2, 2); + auto ones = absl::make_unique>(2, 2, 2, 2); ones->Fill(1.0); auto input = AddParam(*ones, &b); diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 2fc7f816b56db6f57ca835d1847476b6d622ce5e..58539e6b061b0cec1cc660b52e78894e5deeea56 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -31,7 +31,7 @@ class PredTest : public ClientLibraryTestBase { protected: void TestCompare(bool lhs, bool rhs, bool expected, std::function)> + absl::Span)> op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 326e13b3867f2f804e882e00e35850d0189ad8d7..5f322b768d8620cb64a79bb8fca5fecf282f28f5 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" @@ -26,7 +27,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -37,8 +37,7 @@ namespace { class PrngTest : public ClientLibraryTestBase { protected: template - std::unique_ptr UniformTest(T a, T b, - tensorflow::gtl::ArraySlice dims, + std::unique_ptr UniformTest(T a, T b, absl::Span dims, int64 seed = 42); // Computes the χ² statistic of a sample of the discrete uniform distribution @@ -50,8 +49,9 @@ class PrngTest : public ClientLibraryTestBase { }; template -std::unique_ptr PrngTest::UniformTest( - T a, T b, tensorflow::gtl::ArraySlice dims, int64 seed) { +std::unique_ptr PrngTest::UniformTest(T a, T b, + absl::Span dims, + int64 seed) { XlaBuilder builder(TestName()); RngUniform( ConstantR0(&builder, a), ConstantR0(&builder, b), @@ -61,7 +61,7 @@ std::unique_ptr PrngTest::UniformTest( auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - actual->EachCell([=](tensorflow::gtl::ArraySlice, T value) { + actual->EachCell([=](absl::Span, T value) { EXPECT_LE(a, value); EXPECT_LT(value, b); }); @@ -117,7 +117,7 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { for (int64 seed = 0; seed < count; ++seed) { auto result = UniformTest(low, high, {}, /*seed=*/seed); result->Literal::EachCell( - [&](tensorflow::gtl::ArraySlice, bfloat16 value) { + [&](absl::Span, bfloat16 value) { int64 index = static_cast((value - low) / interval); counts[index]++; }); @@ -149,8 +149,8 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); std::vector counts(range_size, 0); - 
actual->EachCell([&counts](tensorflow::gtl::ArraySlice, - int32 value) { ++counts[value]; }); + actual->EachCell( + [&counts](absl::Span, int32 value) { ++counts[value]; }); int64 sum = 0; for (int32 i = 0; i < range_size; ++i) { sum += Square(static_cast(counts[i] - expected_count)); diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index a080dd1732bde21712cf47b4b57538cf4040f30e..9af9ea4a2229bb6ca7c3561350f11837f5072a2c 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -15,11 +15,11 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -29,16 +29,13 @@ limitations under the License. 
namespace xla { namespace { -namespace str_util = tensorflow::str_util; -namespace strings = tensorflow::strings; - struct ReduceLayout { std::array input_minor_to_major; std::array output_minor_to_major; string ToString() const { - return strings::StrCat(str_util::Join(input_minor_to_major, "x"), "_", - str_util::Join(output_minor_to_major, "x")); + return absl::StrCat(absl::StrJoin(input_minor_to_major, "x"), "_", + absl::StrJoin(output_minor_to_major, "x")); } }; diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 531648fe3eb8e3941c5e3c012847ee68c616590f..0916a07f4fa99af6cf25441fa8558a558bfa032f 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -57,8 +58,8 @@ static const int mantissa_sizes[] = {23, 10, 23, 10}; string TestDataToString(const ::testing::TestParamInfo data) { int i = data.param; - return tensorflow::strings::StrCat(exponent_sizes[i], "_exponent_bits_", - mantissa_sizes[i], "_mantissa_bits"); + return absl::StrCat(exponent_sizes[i], "_exponent_bits_", mantissa_sizes[i], + "_mantissa_bits"); } // The FPVAL macro allows us to write out the binary representation of the diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 2065271a7f686c52c88df80b0efe8f2e1542d198..8c62adea231d1d3197c6e483d58008b1577b156d 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -32,6 +32,9 @@ limitations under the License. 
#include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -51,7 +54,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -113,8 +115,7 @@ class ReduceTest : public ClientLibraryTestBase { ErrorSpec(0.001)); } - void RunR1ToR0PredTest(bool and_reduce, - tensorflow::gtl::ArraySlice input_data) { + void RunR1ToR0PredTest(bool and_reduce, absl::Span input_data) { const int element_count = input_data.size(); XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count}); @@ -259,8 +260,8 @@ class ReduceTest : public ClientLibraryTestBase { void ComputeAndCompareGeneric( typename std::enable_if::value, XlaBuilder>::type* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span expected, + absl::Span arguments) { ComputeAndCompareR1(builder, expected, arguments, ErrorSpec(0.01, 1e-4)); } @@ -269,8 +270,8 @@ class ReduceTest : public ClientLibraryTestBase { void ComputeAndCompareGeneric( typename std::enable_if::value, XlaBuilder>::type* builder, - tensorflow::gtl::ArraySlice expected, - tensorflow::gtl::ArraySlice arguments) { + absl::Span expected, + absl::Span arguments) { ComputeAndCompareR1(builder, expected, arguments); } @@ -302,7 +303,7 @@ class ReduceTest : public ClientLibraryTestBase { client_->TransferToServer(*input_literal).ConsumeValueOrDie(); // NativeT can be bool, and std::vector does not convert to - // ArraySlice. + // Span. 
std::unique_ptr expected(new NativeT[cols]); for (int64 colno = 0; colno < cols; ++colno) { NativeT column_result = initial_value; @@ -314,7 +315,7 @@ class ReduceTest : public ClientLibraryTestBase { } ComputeAndCompareGeneric( - &builder, tensorflow::gtl::ArraySlice(expected.get(), cols), + &builder, absl::Span(expected.get(), cols), {input_global_data.get()}); } @@ -556,12 +557,11 @@ struct BoundsLayout { }; void PrintTo(const BoundsLayout& spec, std::ostream* os) { - *os << tensorflow::strings::Printf( - "R%luToR%lu%s_%s_Reduce%s", spec.bounds.size(), - spec.bounds.size() - spec.reduce_dims.size(), - tensorflow::str_util::Join(spec.bounds, "x").c_str(), - tensorflow::str_util::Join(spec.layout, "").c_str(), - tensorflow::str_util::Join(spec.reduce_dims, "").c_str()); + *os << absl::StrFormat("R%uToR%u%s_%s_Reduce%s", spec.bounds.size(), + spec.bounds.size() - spec.reduce_dims.size(), + absl::StrJoin(spec.bounds, "x"), + absl::StrJoin(spec.layout, ""), + absl::StrJoin(spec.reduce_dims, "")); } // Add-reduces a broadcasted scalar matrix among dimension 1 and 0. diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 92c93f08b2e8e543aeaa58020eddacd109b2e2da..997880a018a264de7b0623d27997defdfc68f14a 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -18,6 +18,10 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -35,7 +39,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -54,7 +57,7 @@ class ReduceWindowTestBase : public ClientLibraryTestBase { public: ErrorSpec DefaultErrorSpec() const { if (use_bfloat16()) { - return ErrorSpec(1e-1, 5e-2); + return ErrorSpec(2e-1, 6e-2); } else { return ErrorSpec(1e-3, 1e-3); } @@ -67,8 +70,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); } void ReduceWindowAdd(const XlaOp& input, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), &builder_); @@ -78,8 +81,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, } void ReduceWindowMax(const XlaOp& input, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { auto init = CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_); @@ -89,8 +92,8 @@ class ReduceWindowTest : public ::testing::WithParamInterface, } void ReduceWindowMin(const XlaOp& input, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding) { auto init = CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_); @@ -357,7 +360,7 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - auto arg_literal = MakeUnique(shape); + auto arg_literal = absl::make_unique(shape); 
arg_literal->PopulateWithValue(1.0f); const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); @@ -368,7 +371,7 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - auto expected = MakeUnique(result_shape); + auto expected = absl::make_unique(result_shape); expected->PopulateWithValue(27.0f); ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } @@ -578,21 +581,20 @@ string R4ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // - "__window_bounds_", - tensorflow::str_util::Join(param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), // - "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), // - "__layout_", tensorflow::str_util::Join(param.layout, "_"), // + string str = absl::StrCat( + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), // + "__strides_", absl::StrJoin(param.strides, "x"), // + "__pad_low_", absl::StrJoin(param.pad_low, "x"), // + "__pad_high_", absl::StrJoin(param.pad_high, "x"), // + "__layout_", absl::StrJoin(param.layout, "_"), // (param.reducer == kAdd) ? "_add" : "_max"); CHECK(param.reducer == kAdd || param.reducer == kMax); // Test names are not allowed to contain the '-' character. 
std::replace(str.begin(), str.end(), '-', 'n'); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + str = absl::StrCat(str, "_bfloat16"); } return str; } @@ -934,15 +936,15 @@ string R3ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), - "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), - "__strides_", tensorflow::str_util::Join(param.strides, "x"), - "__padding_", param.padding == Padding::kSame ? "same" : "valid", - "__layout_", param.layout[0], "_", param.layout[1], "_", param.layout[2], - "__reducer_", param.reducer == kAdd ? "add" : "max"); + string str = absl::StrCat( + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_", + absl::StrJoin(param.window_bounds, "x"), "__strides_", + absl::StrJoin(param.strides, "x"), "__padding_", + param.padding == Padding::kSame ? "same" : "valid", "__layout_", + param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_", + param.reducer == kAdd ? 
"add" : "max"); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + str = absl::StrCat(str, "_bfloat16"); } return str; } @@ -1068,17 +1070,16 @@ string R2ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // - "__window_bounds_", - tensorflow::str_util::Join(param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), - "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), - "__layout_", param.layout[0], "_", param.layout[1], // + string str = absl::StrCat( + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), // + "__strides_", absl::StrJoin(param.strides, "x"), // + "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_", + absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_", + param.layout[1], // "__reducer_", param.reducer == kAdd ? 
"add" : "max"); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + str = absl::StrCat(str, "_bfloat16"); } return str; } @@ -1261,21 +1262,27 @@ struct R1ReduceWindowTestData { /*pad_low=*/{5}, /*pad_high=*/{0}, /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, + /*strides=*/{1}, + /*pad_low=*/{4095}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kMax}, }; string R1ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), - "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), - "__strides_", tensorflow::str_util::Join(param.strides, "x"), - "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), - "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), - "__reducer_", param.reducer == kAdd ? "add" : "max"); + string str = + absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"), + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), + "__strides_", absl::StrJoin(param.strides, "x"), + "__pad_low_", absl::StrJoin(param.pad_low, "x"), + "__pad_high_", absl::StrJoin(param.pad_high, "x"), + "__reducer_", param.reducer == kAdd ? 
"add" : "max"); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + str = absl::StrCat(str, "_bfloat16"); } return str; } @@ -1296,7 +1303,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::vector input_vector(param.base_bounds[0]); std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr input_literal = - LiteralUtil::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); + LiteralUtil::CreateR1(absl::Span(input_vector)); XlaOp parameter; auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", &b, ¶meter); @@ -1320,7 +1327,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { ? +[](float a, float b) { return a + b; } : +[](float a, float b) { return std::max(a, b); }; auto expected = ReferenceUtil::ReduceWindow1DGeneric( - /*operand=*/tensorflow::gtl::ArraySlice(input_vector), + /*operand=*/absl::Span(input_vector), /*init=*/kInitValue, /*reduce_func=*/reduce_func, /*window=*/param.window_bounds, @@ -1442,7 +1449,7 @@ ENTRY reduce-window-identity { } )"; - EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); + EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); } XLA_TEST_F(HloTestBase, ReduceWindowS32) { @@ -1461,7 +1468,7 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { } )"; - EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); + EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); } XLA_TEST_F(HloTestBase, ReduceWindowF16) { @@ -1480,7 +1487,7 @@ ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] { } )"; - EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); + EXPECT_TRUE(RunAndCompare(hlo_string, absl::nullopt)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index 368f5583c9ce3773e57b858ff7606f679346529a..ae24eb5eb4822a2057e34a1aec8b7d64604d8984 100644 --- 
a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 382d1b1ae741285dcd1f7761edb82a5c333887af..17d12715f60f624c35169048121ca139d78a544f 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -35,7 +36,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -689,9 +689,8 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 1, 1); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -711,9 +710,8 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 4, 1); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -734,9 +732,8 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(5, 10, 2, 3); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -747,7 +744,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { /*new_sizes=*/{5, 60}); Array2D expected_array(5, 60); - 
input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { + input.Each([&](absl::Span indices, float* cell) { expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) = *cell; }); @@ -762,7 +759,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { std::uniform_real_distribution distribution; Array4D input_array(2, 3, 5, 7); input_array.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, + [&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( @@ -842,9 +839,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { std::vector bounds = {2, 2, 2, 2}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -871,9 +867,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::vector bounds = {1, 1, 250, 300}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -900,9 +895,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::vector bounds = {5, 5, 1, 10}; std::vector new_bounds = {bounds[0], bounds[1], 
bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -930,9 +924,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::vector bounds = {5, 5, 10, 1}; std::vector new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); @@ -959,9 +952,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::vector bounds = {3, 3, 1, 3}; std::vector new_bounds = {bounds[1], bounds[0], bounds[2], bounds[3]}; Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); - input.Each( - [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, - float* cell) { *cell = distribution(rng); }); + input.Each([&rng, &distribution](absl::Span /* indices */, + float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 41e49b4003236d55d85592315652a0ddefd5c485..74ded82ddfae10c21fe98ec2e250b4eaecf95222 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -15,6 +15,8 @@ 
limitations under the License. #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -37,16 +39,14 @@ static std::array use_bfloat16_params{false}; #endif struct ReverseSpec { - tensorflow::gtl::ArraySlice input_dims; - tensorflow::gtl::ArraySlice reversal; + absl::Span input_dims; + absl::Span reversal; bool use_bfloat16; string ToTestCaseName() const { - return tensorflow::strings::Printf( - "reverse_%s_in_dims_%s_%s", - tensorflow::str_util::Join(input_dims, "x").c_str(), - tensorflow::str_util::Join(reversal, "x").c_str(), - use_bfloat16 ? "bf16" : "f32"); + return absl::StrFormat( + "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x"), + absl::StrJoin(reversal, "x"), use_bfloat16 ? "bf16" : "f32"); } }; @@ -91,17 +91,16 @@ TEST_P(FloatReverseTest, Reverses) { std::unique_ptr expected = input_literal->CloneToUnique(); std::vector output_indices(spec.input_dims.size()); - expected->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float) { - for (int64 i = 0; i < indices.size(); ++i) { - output_indices[i] = indices[i]; - } - float value = input_literal->Get(indices); - for (int64 dim : spec.reversal) { - output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; - } - expected->Set(output_indices, value); - }); + expected->EachCell([&](absl::Span indices, float) { + for (int64 i = 0; i < indices.size(); ++i) { + output_indices[i] = indices[i]; + } + float value = input_literal->Get(indices); + for (int64 dim : spec.reversal) { + output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; + } + expected->Set(output_indices, value); + }); ComputeAndCompareLiteral(&builder, *expected, {}); } diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 
a620fe19085d98c8b6642b25b159d6c2308bdae2..e692b8c5d5e661587bac16a2992e35f92c4c0bd9 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/casts.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -47,8 +47,7 @@ class RoundTripPackedLiteralTest : public ClientLibraryTestBase { TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { string data(sizeof(float) * 2, 0); - tensorflow::gtl::MutableArraySlice floats( - tensorflow::bit_cast(data.data()), 2); + absl::Span floats(tensorflow::bit_cast(data.data()), 2); floats[0] = 42.0; floats[1] = 24.0; @@ -70,8 +69,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { string data(sizeof(float) * 4, 0); - tensorflow::gtl::MutableArraySlice floats( - tensorflow::bit_cast(data.data()), 4); + absl::Span floats(tensorflow::bit_cast(data.data()), 4); // With x as the minor dimension, these will become: floats[0] = 42.0; // y=0,x=0 floats[1] = 24.0; // y=0,x=1 @@ -105,8 +103,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { string data(sizeof(float) * 4, 0); - tensorflow::gtl::MutableArraySlice floats( - tensorflow::bit_cast(data.data()), 4); + absl::Span floats(tensorflow::bit_cast(data.data()), 4); // 
With y as the minor dimension, these will become: floats[0] = 42.0; // y=0,x=0 floats[1] = 24.0; // y=1,x=0 diff --git a/tensorflow/compiler/xla/tests/sample_text_test.cc b/tensorflow/compiler/xla/tests/sample_text_test.cc index b4f2b74e3dc9e80f50454b28eb6f2502cef3e681..2b03a0b0b22eb0ae4777417f6640c5f90171d808 100644 --- a/tensorflow/compiler/xla/tests/sample_text_test.cc +++ b/tensorflow/compiler/xla/tests/sample_text_test.cc @@ -19,18 +19,18 @@ limitations under the License. #include #include +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -using tensorflow::gtl::nullopt; +using absl::nullopt; class SampleTextTest : public HloTestBase {}; diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index e42c71eb284deb2e50d6ea4b47fa707e4bc14ffc..07460a7e01a5497aa6411ddb6866dddfc70f2068 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -30,8 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -46,9 +46,8 @@ class ScalarComputationsTest : public ClientLibraryTestBase { // A template for building and running a binary comparison test. template void TestCompare(NativeT lhs, NativeT rhs, bool expected, - std::function)> - op) { + const std::function)>& op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); XlaOp rhs_op = ConstantR0(&builder, rhs); @@ -58,9 +57,8 @@ class ScalarComputationsTest : public ClientLibraryTestBase { template void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - std::function)> - op) { + const std::function)>& op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0(&builder, lhs); XlaOp rhs_op = ConstantR0(&builder, rhs); diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 769d07e9d0cbacce656b9327c13417a13976e3d8..1858dcea61241a2aeee11592a9b09f200763b25a 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -23,7 +23,7 @@ limitations under the License. 
namespace xla { namespace { -using tensorflow::gtl::nullopt; +using absl::nullopt; class ScatterTest : public HloTestBase { protected: @@ -32,8 +32,7 @@ class ScatterTest : public HloTestBase { RunTest(hlo_text, {operand, scatter_indices, updates}); } - void RunTest(const string& hlo_text, - tensorflow::gtl::ArraySlice args) { + void RunTest(const string& hlo_text, absl::Span args) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -594,21 +593,20 @@ update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { } ENTRY main { - operand = s32[3,3] parameter(0) + operand = s32[3] parameter(0) indices = s32[0] parameter(1) - updates = s32[0,0] parameter(2) - ROOT scatter = s32[3,3] scatter(operand, indices, updates), + updates = s32[0] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), to_apply=update_s32, - update_window_dims={1}, + update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1 } )"; - std::unique_ptr operand = - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3}); std::unique_ptr scatter_indices = LiteralUtil::CreateR1({}); - std::unique_ptr updates = LiteralUtil::CreateR2({{}}); + std::unique_ptr updates = LiteralUtil::CreateR1({}); RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); } diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index e3d4f98dd7432d1dce7e697586e8b17105dc82e7..f737b5158b3622d677aea5bf64a421a56e2c42dd 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -42,8 +42,8 @@ struct SelectAndScatterTestParam { std::vector operand_shape; std::vector source_shape; Padding padding_type; - tensorflow::gtl::ArraySlice window_dimensions; - tensorflow::gtl::ArraySlice 
window_strides; + absl::Span window_dimensions; + absl::Span window_strides; }; class SelectAndScatterTest diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index b8ad6668f80a3002eff3cc458997966ee67c8d4b..c9a58aefb4acc066c10e98aea46375523cf554d0 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -18,6 +18,11 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -25,16 +30,12 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -using ::tensorflow::str_util::Join; - class SliceTest : public ClientLibraryTestBase {}; TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { @@ -193,9 +194,9 @@ class SliceR1Test : public ClientLibraryTestBase, protected: template void Run(const R1Spec& spec) { - // This can't be an std::vector, since you can't grab an ArraySlice of a + // This can't be an std::vector, since you can't grab a Span of a // vector. - tensorflow::gtl::InlinedVector input(spec.input_dim0); + absl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); auto literal = LiteralUtil::CreateR1(input); @@ -205,7 +206,7 @@ class SliceR1Test : public ClientLibraryTestBase, {spec.slice_stride}); // Ditto. 
- tensorflow::gtl::InlinedVector expected; + absl::InlinedVector expected; for (int i = spec.slice_start; i < spec.slice_limit; i += spec.slice_stride) { expected.push_back(i); @@ -222,9 +223,8 @@ class SliceR1LargeTest : public SliceR1Test {}; string SliceR1TestDataToString(const ::testing::TestParamInfo& data) { const R1Spec& spec = data.param; - return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, - spec.slice_start, spec.slice_limit, - spec.slice_stride); + return absl::StrFormat("%d_%d_%d_%d", spec.input_dim0, spec.slice_start, + spec.slice_limit, spec.slice_stride); } XLA_TEST_P(SliceR1Test, DoIt_F32) { Run(GetParam()); } @@ -448,13 +448,11 @@ struct R4Spec { string R4SpecToString(const ::testing::TestParamInfo& data) { const R4Spec& spec = data.param; - return tensorflow::strings::StrCat( // - "input_", Join(spec.input_dims, "x"), // - "__layout_", Join(spec.input_layout, ""), // - "__starts_", Join(spec.slice_starts, "x"), // - "__limits_", Join(spec.slice_limits, "x"), // - "__strides_", Join(spec.slice_strides, "x") // - ); + return absl::StrCat("input_", absl::StrJoin(spec.input_dims, "x"), + "__layout_", absl::StrJoin(spec.input_layout, ""), + "__starts_", absl::StrJoin(spec.slice_starts, "x"), + "__limits_", absl::StrJoin(spec.slice_limits, "x"), + "__strides_", absl::StrJoin(spec.slice_strides, "x")); } class SliceR4Test : public ClientLibraryTestBase, diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index be35ec6c6ee4c015755622b2dc9bb92e23af7c85..a9874a918659f1d7403ba0c5cb968e62d7091936 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -20,7 +20,9 @@ limitations under the License. 
#include #include -#include "tensorflow/core/lib/strings/str_util.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" @@ -44,7 +46,7 @@ ManifestT ReadManifest() { string contents((std::istreambuf_iterator(file_stream)), std::istreambuf_iterator()); - std::vector lines = tensorflow::str_util::Split(contents, '\n'); + std::vector lines = absl::StrSplit(contents, '\n'); for (string& line : lines) { auto comment = line.find("//"); if (comment != string::npos) { @@ -53,8 +55,8 @@ ManifestT ReadManifest() { if (line.empty()) { continue; } - tensorflow::str_util::StripTrailingWhitespace(&line); - std::vector pieces = tensorflow::str_util::Split(line, ' '); + absl::StripTrailingAsciiWhitespace(&line); + std::vector pieces = absl::StrSplit(line, ' '); CHECK_GE(pieces.size(), 1); auto& platforms = manifest[pieces[0]]; for (int64 i = 1; i < pieces.size(); ++i) { @@ -73,8 +75,7 @@ string PrependDisabledIfIndicated(const string& test_case_name, // First try full match: test_case_name.test_name // If that fails, try to find just the test_case_name; this would disable all // tests in the test case. - auto it = manifest.find( - tensorflow::strings::StrCat(test_case_name, ".", test_name)); + auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name)); if (it == manifest.end()) { it = manifest.find(test_case_name); if (it == manifest.end()) { diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 2647937013222ccfdae98b0c1d141f461020b5c9..c20a7c8fe49cd6b9161251488b85e08459f68865 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/tests/test_utils.h" +#include + +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" namespace xla { @@ -26,89 +29,101 @@ namespace { template void PopulateWithRandomFloatingPointDataImpl(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); - // Create uniform numbers between 1 and 1.125 to avoid creating denormal - // numbers. - std::uniform_real_distribution generator(1.0f, 1.125f); - const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice indices) { - // Generate a random uniform number from -0.0625 and 0.0625 and bias it - // with a position dependent number with mean 0.037109375. These number - // should allow for long chains of accumulation without being too close - // to zero or too large to accumulate all numbers accurately. Only do - // this for large literals where the number of elements is much greater - // than 47 otherwise only negative values are produced. - // - // The value is positionally biased using a product of the indices. Add - // one to each index value to avoid collapsing to zero if any of the - // indices are zero. - int64 index_product = 1; - for (int64 i : indices) { - index_product *= (1 + i); - } - const int64 negative_bias = should_index_bias ? 
47 : 0; - FloatT index_bias = - static_cast(index_product % 113 - negative_bias) / - static_cast(256.0f); - return static_cast(generator(*engine) - 1.0625f) + index_bias; - })); + if (no_duplicates) { + // Duplicates may be generated if the number of elements in the literal + // exceeds the number of positive values supported by the type. + FloatT next_value = std::numeric_limits::min(); + for (FloatT& value : literal->data()) { + value = next_value; + next_value = + std::nextafter(next_value, std::numeric_limits::max()); + } + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); + } else { + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (FloatT& value : literal->data()) { + value = static_cast(generator(*engine)); + } + } } template void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); - PopulateWithRandomFloatingPointDataImpl(literal, engine); + PopulateWithRandomFloatingPointDataImpl(literal, engine, + no_duplicates); } template <> void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { + // no_duplicates is ignored for half types. Unique values can only be + // generated for arrays with fewer than ~2**16 elements and no_duplicates is + // best-effort anyway. CHECK(engine != nullptr); - PopulateWithRandomFloatingPointDataImpl(literal, engine); + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (half& value : literal->data()) { + value = static_cast(generator(*engine)); + } } -// The standard library does not have a case for bfloat16, unsurprisingly, so we -// handle that one specially. template <> void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { + std::minstd_rand0* engine, + bool no_duplicates) { + // no_duplicates is ignored for bfloat types. 
Unique values can only be + // generated for arrays with fewer than ~2**16 elements and no_duplicates is + // best-effort anyway. CHECK(engine != nullptr); - CHECK_EQ(literal->shape().element_type(), BF16); - std::uniform_real_distribution generator(-0.9f, 1.0f); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return static_cast(generator(*engine)); - })); + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (bfloat16& value : literal->data()) { + value = static_cast(generator(*engine)); + } } template -void PopulateWithRandomIntegralData(Literal* literal, - std::minstd_rand0* engine) { +void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine, + bool no_duplicates) { CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); - std::uniform_int_distribution generator( - std::numeric_limits::lowest(), std::numeric_limits::max()); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(*engine); - })); + if (no_duplicates && ShapeUtil::ElementsIn(literal->shape()) < + std::numeric_limits::max()) { + std::iota(literal->data().begin(), literal->data().end(), 0); + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); + } else { + std::uniform_int_distribution generator( + std::numeric_limits::lowest(), std::numeric_limits::max()); + for (IntT& value : literal->data()) { + value = generator(*engine); + } + } } // Similar to MakeFakeLiteral but takes a random number generator engine to -// enable reusing the engine across randomly generated literals. +// enable reusing the engine across randomly generated literals. 'no_duplicates' +// indicates that there should be no duplicate values in each generated +// array. This is uniqueness is best-effort only. 
Some types (half and bfloat16) +// are not supported and uniqueness cannot be guaranteed if the number of +// elements exceeds the number of different values supported by the type. StatusOr> MakeFakeLiteralInternal( - const Shape& shape, std::minstd_rand0* engine) { + const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) { if (ShapeUtil::IsTuple(shape)) { std::vector> elements; for (const Shape& element_shape : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr element, - MakeFakeLiteralInternal(element_shape, engine)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr element, + MakeFakeLiteralInternal(element_shape, engine, no_duplicates)); elements.push_back(std::move(element)); } return LiteralUtil::MakeTupleOwned(std::move(elements)); @@ -116,48 +131,60 @@ StatusOr> MakeFakeLiteralInternal( if (engine == nullptr) { return Literal::CreateFromShape(shape); } - auto literal = MakeUnique(shape); + auto literal = absl::make_unique(shape); switch (shape.element_type()) { case BF16: - PopulateWithRandomFloatingPointData(literal.get(), engine); + PopulateWithRandomFloatingPointData(literal.get(), engine, + no_duplicates); break; case F16: - PopulateWithRandomFloatingPointData(literal.get(), engine); + PopulateWithRandomFloatingPointData(literal.get(), engine, + no_duplicates); break; case F32: - PopulateWithRandomFloatingPointData(literal.get(), engine); + PopulateWithRandomFloatingPointData(literal.get(), engine, + no_duplicates); break; case F64: - PopulateWithRandomFloatingPointData(literal.get(), engine); + PopulateWithRandomFloatingPointData(literal.get(), engine, + no_duplicates); break; case S8: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case U8: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case S16: - PopulateWithRandomIntegralData(literal.get(), engine); + 
PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case U16: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case S32: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case U32: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case S64: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case U64: - PopulateWithRandomIntegralData(literal.get(), engine); + PopulateWithRandomIntegralData(literal.get(), engine, + no_duplicates); break; case PRED: { std::uniform_int_distribution generator(0, 1); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { + TF_CHECK_OK( + literal->Populate([&](absl::Span /*indices*/) { return generator(*engine); })); break; @@ -167,7 +194,7 @@ StatusOr> MakeFakeLiteralInternal( break; default: return Unimplemented("Unsupported type for fake literal generation: %s", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } return std::move(literal); } @@ -176,6 +203,7 @@ enum class ConstantType { kUnknown, kZero, kOne }; // Return the constant type required by this computation, if known. ConstantType GetInitValue(const HloComputation& computation) { + // TODO(b/77635120): Add init values, for min, max, and their arg variants. 
const HloInstruction* const root = computation.root_instruction(); if (computation.num_parameters() != 2 || root->operand_count() != 2 || root->operand(0)->opcode() != HloOpcode::kParameter || @@ -200,24 +228,20 @@ bool NeedsInitValue(const HloUse& use) { const HloInstruction* const instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); const int64 op_num = use.operand_number; - return ( - ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && - op_num == 1) || - (opcode == HloOpcode::kSelectAndScatter && op_num == 2)); + return ((opcode == HloOpcode::kReduceWindow && op_num == 1) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2) || + (opcode == HloOpcode::kReduce && + op_num >= instruction->operand_count() / 2)); } // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -std::unique_ptr MakeRandomNonwrappingSliceIndex( - const Shape& input_shape, const Shape& slice_shape, - std::minstd_rand0* engine) { - const int64 rank = ShapeUtil::Rank(input_shape); - std::vector start_indices(rank); +std::unique_ptr MakeRandomIndex(absl::Span index_space, + std::minstd_rand0* engine) { + std::vector start_indices(index_space.size()); if (engine != nullptr) { - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution generator(0, upper_bound); + for (int i = 0; i < index_space.size(); ++i) { + std::uniform_int_distribution generator(0, index_space[i]); start_indices[i] = generator(*engine); } } @@ -254,6 +278,11 @@ std::vector FindConstrainedUses( auto converted_uses = FindConstrainedUses(dataflow, *instruction); constrained_uses.insert(constrained_uses.end(), converted_uses.begin(), converted_uses.end()); + } else if (opcode == HloOpcode::kSort && + instruction->operand_count() == 2 && op_num == 0) { + // Operand 0 of 
sort is the array of keys used for key/value + // (two-operand) kSort instructions. + constrained_uses.push_back(instruction); } } } @@ -265,58 +294,68 @@ std::vector FindConstrainedUses( // generate a constrained literal (either bounded in the case of indices, or // zero in the case of init_values for reductions). StatusOr> CreateLiteralForConstrainedUses( - const tensorflow::gtl::ArraySlice constrained_uses, + const absl::Span constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { - HloInstruction* needs_index = nullptr; - HloInstruction* needs_constant = nullptr; + std::vector index_space; + bool no_duplicates = false; + bool needs_constant = false; ConstantType constant_type = ConstantType::kUnknown; for (HloInstruction* use : constrained_uses) { switch (use->opcode()) { case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - if (needs_index != nullptr) { - auto needs_index_shape = needs_index->shape(); - auto use_shape = use->shape(); - if (needs_index->opcode() == HloOpcode::kDynamicSlice) { - needs_index_shape = needs_index->operand(0)->shape(); - } - if (use->opcode() == HloOpcode::kDynamicSlice) { - use_shape = use->operand(0)->shape(); + case HloOpcode::kDynamicUpdateSlice: { + const Shape& indexed_shape = use->operand(0)->shape(); + const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice + ? 
use->shape() + : use->operand(1)->shape(); + const int64 rank = ShapeUtil::Rank(indexed_shape); + if (!index_space.empty()) { + TF_RET_CHECK(rank == index_space.size()); + for (int64 i = 0; i < rank; ++i) { + index_space[i] = std::min( + index_space[i], ShapeUtil::GetDimension(indexed_shape, i) - + ShapeUtil::GetDimension(slice_shape, i)); } - if (!ShapeUtil::Equal(needs_index_shape, use_shape)) { - return Unimplemented( - "Conflicting operand generation slice index constraints\n"); + } else { + index_space.resize(rank); + for (int64 i = 0; i < rank; ++i) { + index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); } } - needs_index = use; break; + } case HloOpcode::kReduce: case HloOpcode::kReduceWindow: - needs_constant = use; + needs_constant = true; constant_type = GetInitValue(*use->to_apply()); break; case HloOpcode::kSelectAndScatter: - needs_constant = use; + needs_constant = true; constant_type = GetInitValue(*use->scatter()); break; + case HloOpcode::kSort: + no_duplicates = true; + break; + default: return Unimplemented( "Constrained operand generation not implemented for %s.", - use->ToString().c_str()); + use->ToString()); } } - if (needs_index != nullptr && needs_constant != nullptr) { - return Unimplemented( - "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " - "constant: %s\n", - needs_index->ToString().c_str(), needs_constant->ToString().c_str()); + int constraint_count = 0; + constraint_count += no_duplicates ? 1 : 0; + constraint_count += !index_space.empty() ? 1 : 0; + constraint_count += needs_constant ? 
1 : 0; + if (constraint_count > 1) { + return Unimplemented("Conflicting operand generation constraints."); } - if (needs_index != nullptr) { - return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), - needs_index->shape(), engine); - } else if (needs_constant != nullptr) { + if (!index_space.empty()) { + return MakeRandomIndex(index_space, engine); + } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); @@ -325,10 +364,11 @@ StatusOr> CreateLiteralForConstrainedUses( case ConstantType::kUnknown: // We want the identity element for the computation, but we don't really // know what it is - so any value we generate will be just as wrong. - return MakeFakeLiteralInternal(param.shape(), engine); + return MakeFakeLiteralInternal(param.shape(), engine, + /*no_duplicates=*/false); } } else { - return MakeFakeLiteralInternal(param.shape(), engine); + return MakeFakeLiteralInternal(param.shape(), engine, no_duplicates); } } @@ -345,25 +385,36 @@ StatusOr> MakeConstrainedArgument( StatusOr> MakeFakeLiteral(const Shape& shape, bool pseudo_random) { - auto engine = pseudo_random ? MakeUnique() : nullptr; - return MakeFakeLiteralInternal(shape, engine.get()); + auto engine = + pseudo_random ? absl::make_unique() : nullptr; + return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false); } StatusOr>> MakeFakeArguments( HloModule* const module, bool pseudo_random) { + auto engine = + pseudo_random ? absl::make_unique() : nullptr; + return MakeFakeArguments(module, engine.get()); +} + +StatusOr>> MakeFakeArguments( + HloModule* const module, std::minstd_rand0* engine) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - auto engine = pseudo_random ? 
MakeUnique() : nullptr; std::vector> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( - *dataflow, *params[i], engine.get())); + arguments[i] = + MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie(); } return std::move(arguments); } -Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) { - return HloVerifier(allow_mixed_precision).Run(module).status(); +Status VerifyHloModule(HloModule* const module, bool layout_sensitive, + bool allow_mixed_precision) { + return HloVerifier(/*layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision=*/allow_mixed_precision) + .Run(module) + .status(); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index e59f215a9a3ace80d7a23e1bbc40970c7a63ea0d..7790737c093ad8e5a15c017e3f7890b6f25cb6f8 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -20,12 +20,12 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/platform.h" @@ -63,8 +63,17 @@ StatusOr> MakeFakeLiteral(const Shape& shape, // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. // -// Will handle special cases such as making sure that indices used for dynamic -// slices are bounded, reduces that call adds use 0 as an init value, etc. 
+// A best-effort attempt is made to generate the data in a way which produce +// stable computation results across platforms. Specifically: +// +// (1) Init values of reductions should be the identity of the reduction +// computation. +// +// (2) Indices of dynamic slices and update slices should be in bounds. +// +// (3) Keys of key/value sorts should contain no duplicates. +// +// These constraints are best-effort only. // // If pseudo_random is true, the generated numbers will be generated // deterministically in a pseudo random way unless the values are constrated to @@ -78,10 +87,16 @@ StatusOr> MakeFakeLiteral(const Shape& shape, StatusOr>> MakeFakeArguments( HloModule* const module, bool pseudo_random = true); +// Overload which accepts a random number generator. This enables generation of +// different random values with sequential calls to MakeFakeArguments by reusing +// the same generator. +StatusOr>> MakeFakeArguments( + HloModule* const module, std::minstd_rand0* engine); + // Check that a given module satisfies various constraints before trying to // execute it. -Status VerifyHloModule(HloModule* const module, - bool allow_mixed_precision = false); +Status VerifyHloModule(HloModule* const module, bool layout_sensitive, + bool allow_mixed_precision); } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index a2f0338e25977d7c76dbc48b3afc649b77ba4ee2..322c8ef090cf867f65cada5cb1dbae188f83bad6 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -72,5 +73,106 @@ XLA_TEST_F(TestUtilsTest, Token) { TF_ASSERT_OK(MakeFakeArguments(module.get()).status()); } +XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { + auto module = ParseHloString( + R"(HloModule index_space_module + + ENTRY IndexSpace { + index_param = s32[3]{0} parameter(0) + array_param.1 = f32[123,4,789]{0,1,2} parameter(1) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) + dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3} + ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} + })") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 3); + const Literal& index_arg = *args[0]; + + EXPECT_EQ(index_arg.Get({0}), 0); + + EXPECT_GE(index_arg.Get({1}), 0); + EXPECT_LE(index_arg.Get({1}), 2); + + EXPECT_GE(index_arg.Get({2}), 0); + EXPECT_LE(index_arg.Get({2}), 3); +} + +XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { + auto module = ParseHloString( + R"(HloModule index_space_module + + ENTRY IndexSpace { + index_param = s32[3]{0} parameter(0) + array_param.1 = f32[123,4,789]{0,1,2} parameter(1) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) + update_param.1 = f32[1,2,3]{0,1,2} parameter(3) + update_param.2 = f32[3,2,2]{0,1,2} parameter(4) + + dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param) + ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) + })") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + MakeFakeArguments(module.get())); 
+ ASSERT_EQ(args.size(), 5); + const Literal& index_arg = *args[0]; + + EXPECT_EQ(index_arg.Get({0}), 0); + + EXPECT_GE(index_arg.Get({1}), 0); + EXPECT_LE(index_arg.Get({1}), 2); + + EXPECT_GE(index_arg.Get({2}), 0); + EXPECT_LE(index_arg.Get({2}), 3); +} + +XLA_TEST_F(TestUtilsTest, NoDuplicatesFloats) { + // Inputs which are sort keys in key/value sorts should have no duplicates. + auto module = ParseHloString(R"( +HloModule sort.148.1589 + +ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) { + %parameter.0 = f32[1048576]{0} parameter(0) + %parameter.1 = s32[1048576]{0} parameter(1) + ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0} +} +)") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + const Literal& key_arg = *args[0]; + + tensorflow::gtl::FlatSet key_set; + for (const float& value : key_arg.data()) { + EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); + } +} + +XLA_TEST_F(TestUtilsTest, NoDuplicatesInt32) { + // Inputs which are sort keys in key/value sorts should have no duplicates. 
+ auto module = ParseHloString(R"( +HloModule sort.148.1589 + +ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) { + %parameter.0 = s32[1048576]{0} parameter(0) + %parameter.1 = s32[1048576]{0} parameter(1) + ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0} +} +)") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 2); + const Literal& key_arg = *args[0]; + + tensorflow::gtl::FlatSet key_set; + for (const int32& value : key_arg.data()) { + EXPECT_TRUE(key_set.insert(tensorflow::bit_cast(value)).second); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 2bdbd08309a81b201fc224110805549f7fb5bb55..c7eb9e2dbe0e27b7933f5861280a3401cd268c08 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -15,11 +15,10 @@ limitations under the License. 
#include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -67,7 +66,10 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT( status.error_message(), @@ -84,7 +86,10 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { "param")); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT( status.error_message(), @@ -101,7 +106,10 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT(status.error_message(), ::testing::HasSubstr( diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 97bbf80aff80e995ea5cdd3e5d8807ee4d380067..f2b3b49015c7d74d786f63776abff1d5181fd961 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -16,6 
+16,7 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -504,9 +505,9 @@ XLA_TEST_F(TupleTest, ComplexTuples) { LiteralUtil::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = MakeUnique(sum->shape()); + auto prod = absl::make_unique(sum->shape()); ASSERT_TRUE(prod->Populate( - [&sum](tensorflow::gtl::ArraySlice indexes) { + [&sum](absl::Span indexes) { return sum->Get(indexes) * (indexes[indexes.size() - 1] == 0 ? complex64(1, 2) diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 20ae68ab74026936c43e5f525eb796eb402a19cb..8f80a9f3e466d73f2b718452d9a0d64a80c3b36f 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -190,25 +190,6 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { SignAbsTestHelper(); } -XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { - XlaBuilder builder(TestName()); - auto arg = ConstantR1( - &builder, {2, 25, 0, 123, std::numeric_limits::max()}); - Abs(arg); - - ComputeAndCompareR1( - &builder, {2, 25, 0, 123, std::numeric_limits::max()}, {}); -} - -XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { - XlaBuilder builder(TestName()); - auto arg = ConstantR1( - &builder, {2, 25, 0, 123, std::numeric_limits::max()}); - Sign(arg); - - ComputeAndCompareR1(&builder, {1, 1, 0, 1, 1}, {}); -} - XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { XlaBuilder builder(TestName()); auto arg = ConstantR2(&builder, {{1.0, -2.0}, {-3.0, 4.0}}); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 11f3efb1f34ad23ebdcbb65c90aa5fb7a6adeae5..7fd42944debe38abbf6f0ca36bc5c7ecb1aeaf97 100644 --- 
a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -16,6 +16,10 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -29,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -81,8 +84,7 @@ struct ParsedProfileOutputLine { Status ParseOneProfileOutputLine( const string& line, bool expect_hlo, gtl::FlatMap* parsed_results, - tensorflow::gtl::ArraySlice opcodes_to_ignore = - {}) { + absl::Span opcodes_to_ignore = {}) { string separator = "[^:]*:: +"; string match_percentage = R"(\d+\.\d*% +\d+Σ)"; string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))"; @@ -99,7 +101,7 @@ Status ParseOneProfileOutputLine( string match_opcode = expect_hlo ? 
"%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])"; - string regexp_pattern = tensorflow::strings::StrCat( + string regexp_pattern = absl::StrCat( " +", match_cycles, separator, match_usecs, separator, match_flops, separator, match_trops, separator, match_bytes_per_sec, separator, match_bytes_per_cycle, separator, match_opcode); @@ -116,7 +118,7 @@ Status ParseOneProfileOutputLine( ", Regexp: ", regexp_pattern); } - if (!c_linear_search(opcodes_to_ignore, parsed_line.opcode)) { + if (!absl::c_linear_search(opcodes_to_ignore, parsed_line.opcode)) { InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); } @@ -169,10 +171,10 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, ServiceExecutableRunOptions run_options( exec_run_options, /*borrow_stream=*/nullptr, backend->eigen_intra_op_thread_pool()); + std::vector args = {&lhs_arg, &rhs_arg}; TF_ASSERT_OK_AND_ASSIGN( auto execution_result, - executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg}, - &hlo_execution_profile)); + executable->ExecuteOnStream(&run_options, args, &hlo_execution_profile)); TF_ASSERT_OK(stream_ptr->BlockHostUntilDone()); (void)execution_result; @@ -204,7 +206,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { rhs_shape); std::vector profile_output_lines = - tensorflow::str_util::Split(profile_output, '\n'); + absl::StrSplit(profile_output, '\n'); gtl::FlatMap parsed_profile_lines; @@ -291,22 +293,20 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { matrix_shape); std::vector profile_output_lines = - tensorflow::str_util::Split(profile_output, '\n'); + absl::StrSplit(profile_output, '\n'); auto while_body_profile_start = - c_find_if(profile_output_lines, [](tensorflow::StringPiece s) { - return tensorflow::str_util::StartsWith(s, - "Execution profile for body"); + absl::c_find_if(profile_output_lines, [](absl::string_view s) { + return absl::StartsWith(s, "Execution profile for body"); }); ASSERT_NE(while_body_profile_start, 
profile_output_lines.cend()); - auto while_body_profile_end = - std::find_if(while_body_profile_start, profile_output_lines.end(), - [](tensorflow::StringPiece s) { - return tensorflow::str_util::StartsWith( - s, "********** microseconds report **********"); - }); + auto while_body_profile_end = std::find_if( + while_body_profile_start, profile_output_lines.end(), + [](absl::string_view s) { + return absl::StartsWith(s, "********** microseconds report **********"); + }); // We emit a blank line before the "********** microseconds report **********" // line. diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index a075195618c42aaa11f7b1c17730e67889a2c308..15603619b62d8f45cdce97ac7d83924a78f88cf3 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -32,16 +32,14 @@ GTEST_API_ int main(int argc, char** argv) { // If the --benchmarks flag is passed in then only run the benchmarks, not the // tests. 
for (int i = 1; i < argc; i++) { - tensorflow::StringPiece arg(argv[i]); - if (arg == "--benchmarks" || - tensorflow::str_util::StartsWith(arg, "--benchmarks=")) { + absl::string_view arg(argv[i]); + if (arg == "--benchmarks" || absl::StartsWith(arg, "--benchmarks=")) { const char* pattern = nullptr; - if (tensorflow::str_util::StartsWith(arg, "--benchmarks=")) { + if (absl::StartsWith(arg, "--benchmarks=")) { pattern = argv[i] + strlen("--benchmarks="); } else { // Handle flag of the form '--benchmarks foo' (no '='). - if (i + 1 >= argc || - tensorflow::str_util::StartsWith(argv[i + 1], "--")) { + if (i + 1 >= argc || absl::StartsWith(argv[i + 1], "--")) { LOG(ERROR) << "--benchmarks flag requires an argument."; return 2; } diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 897123d7606db60abc1105b03beb3f23ab249579..442e66321ee732f3d9cdfe4931433bd864b7fa82 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -20,25 +20,28 @@ limitations under the License. 
#include #include +#include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace xla { StatusOr> TextLiteralReader::ReadPath( - tensorflow::StringPiece path) { - CHECK(!tensorflow::str_util::EndsWith(path, ".gz")) + absl::string_view path) { + CHECK(!absl::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; Status s = @@ -54,33 +57,6 @@ StatusOr> TextLiteralReader::ReadPath( TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file) : file_(file) {} -namespace { -// This is an optimized version of tensorflow::str_util::Split which uses -// StringPiece for the delimited strings and uses an out parameter for the -// result to avoid vector creation/destruction. -void SplitByDelimToStringPieces(tensorflow::StringPiece text, char delim, - std::vector* result) { - result->clear(); - - if (text.empty()) { - return; - } - - // The following loop is a little strange: its bound is text.size() + 1 - // instead of the more typical text.size(). - // The final iteration of the loop (when i is equal to text.size()) handles - // the trailing token. 
- size_t token_start = 0; - for (size_t i = 0; i < text.size() + 1; i++) { - if (i == text.size() || text[i] == delim) { - tensorflow::StringPiece token(text.data() + token_start, i - token_start); - result->push_back(token); - token_start = i + 1; - } - } -} -} // namespace - StatusOr> TextLiteralReader::ReadAllLines() { tensorflow::io::RandomAccessInputStream stream(file_.get()); tensorflow::io::BufferedInputStream buf(&stream, 65536); @@ -90,61 +66,55 @@ StatusOr> TextLiteralReader::ReadAllLines() { return s; } - tensorflow::StringPiece sp(shape_string); - if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) { - string tmp = std::string(sp); - shape_string = tmp; - } + absl::StripAsciiWhitespace(&shape_string); TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); if (shape.element_type() != F32) { return Unimplemented( "unsupported element type for text literal reading: %s", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } - auto result = MakeUnique(shape); + auto result = absl::make_unique(shape); const float fill = std::numeric_limits::quiet_NaN(); result->PopulateWithValue(fill); - std::vector pieces; - std::vector coordinates; + std::vector pieces; + std::vector coordinates; std::vector coordinate_values; string line; while (buf.ReadLine(&line).ok()) { - SplitByDelimToStringPieces(line, ':', &pieces); - tensorflow::StringPiece coordinates_string = pieces[0]; - tensorflow::StringPiece value_string = pieces[1]; - tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string); - tensorflow::str_util::RemoveWhitespaceContext(&value_string); - if (!tensorflow::str_util::ConsumePrefix(&coordinates_string, "(")) { + pieces = absl::StrSplit(line, ':'); + absl::string_view coordinates_string = + absl::StripAsciiWhitespace(pieces[0]); + absl::string_view value_string = absl::StripAsciiWhitespace(pieces[1]); + if (!absl::ConsumePrefix(&coordinates_string, "(")) { return InvalidArgument( - "expected '(' at 
the beginning of coordinates: \"%s\"", line.c_str()); + "expected '(' at the beginning of coordinates: \"%s\"", line); } - if (!tensorflow::str_util::ConsumeSuffix(&coordinates_string, ")")) { + if (!absl::ConsumeSuffix(&coordinates_string, ")")) { return InvalidArgument("expected ')' at the end of coordinates: \"%s\"", - line.c_str()); + line); } float value; - if (!tensorflow::strings::safe_strtof(std::string(value_string).c_str(), - &value)) { + if (!absl::SimpleAtof(value_string, &value)) { return InvalidArgument("could not parse value as float: \"%s\"", - std::string(value_string).c_str()); + value_string); } - SplitByDelimToStringPieces(coordinates_string, ',', &coordinates); + coordinates = absl::StrSplit(coordinates_string, ','); coordinate_values.clear(); - for (tensorflow::StringPiece piece : coordinates) { + for (absl::string_view piece : coordinates) { int64 coordinate_value; - if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) { + if (!absl::SimpleAtoi(piece, &coordinate_value)) { return InvalidArgument( "could not parse coordinate member as int64: \"%s\"", - std::string(piece).c_str()); + std::string(piece)); } coordinate_values.push_back(coordinate_value); } if (coordinate_values.size() != shape.dimensions_size()) { return InvalidArgument( - "line did not have expected number of coordinates; want %d got %zu: " + "line did not have expected number of coordinates; want %d got %u: " "\"%s\"", - shape.dimensions_size(), coordinate_values.size(), line.c_str()); + shape.dimensions_size(), coordinate_values.size(), line); } result->Set(coordinate_values, value); } diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index 708e8c80d8b5c09454eb64d4e12df51a5b7ea628..b265640802c88847ce57e9f942f9f0859b873ae8 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -18,11 +18,11 @@ limitations under the License. 
#include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" @@ -41,8 +41,7 @@ class TextLiteralReader { public: // See class comment -- reads a file in its entirety (there must be only one // literal in the text file path provided). - static StatusOr> ReadPath( - tensorflow::StringPiece path); + static StatusOr> ReadPath(absl::string_view path); private: // Ownership of file is transferred. diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 24e0784741a4c9779b0adb7a7740c3d6e2fb033a..7289ae7df65e56652eeeb67e536e4c721d97d999 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -17,23 +17,23 @@ limitations under the License. 
#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" namespace xla { -/* static */ Status TextLiteralWriter::WriteToPath( - const Literal& literal, tensorflow::StringPiece path) { +/* static */ Status TextLiteralWriter::WriteToPath(const Literal& literal, + absl::string_view path) { std::unique_ptr f; - auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f); + auto s = tensorflow::Env::Default()->NewWritableFile(string(path), &f); if (!s.ok()) { return s; } @@ -46,16 +46,14 @@ namespace xla { Status status; tensorflow::WritableFile* f_ptr = f.get(); literal.EachCellAsString( - [f_ptr, &status](tensorflow::gtl::ArraySlice indices, - const string& value) { + [f_ptr, &status](absl::Span indices, const string& value) { if (!status.ok()) { return; } - string coordinates = tensorflow::strings::StrCat( - "(", tensorflow::str_util::Join(indices, ", "), ")"); + string coordinates = + absl::StrCat("(", absl::StrJoin(indices, ", "), ")"); - status = f_ptr->Append( - tensorflow::strings::StrCat(coordinates, ": ", value, "\n")); + status = f_ptr->Append(absl::StrCat(coordinates, ": ", value, "\n")); }); auto ignored = f->Close(); return status; diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 159ac1b7e1b6f9c07dac795fb640cd0b2d284bcb..34de8572d638067b327711017ee173b16c8da21e 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ 
b/tensorflow/compiler/xla/text_literal_writer.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ #define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -37,8 +37,7 @@ namespace xla { // This should be readable by xla::TextLiteralReader. class TextLiteralWriter { public: - static Status WriteToPath(const Literal& literal, - tensorflow::StringPiece path); + static Status WriteToPath(const Literal& literal, absl::string_view path); private: TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 40d28a57bfddd3403cad8252df985b746362631f..3a086c66bbb37965b1ad7c83a93f0054ae723e87 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -24,6 +24,8 @@ tf_cc_binary( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/base", + "@com_google_absl//absl/strings", ], ) @@ -42,6 +44,7 @@ cc_library( "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -67,6 +70,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -94,6 +98,7 @@ cc_library( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], alwayslink = True, ) @@ -172,6 +177,7 @@ tf_cc_binary( 
"//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -191,6 +197,9 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", ], ) @@ -210,6 +219,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index f20dcef382b86d27d7c176ae7e4132ad1db7b901..c866a13de7543fc948311f94708bc6b904717b62 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -28,6 +28,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -46,7 +46,7 @@ limitations under the License. 
namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { HloSnapshot module; @@ -77,8 +77,8 @@ int main(int argc, char** argv) { } tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index f0af0580c1fbca455c6ed5f87f82971faee50a06..4375e7c138c9e8d193feaa7a39d63946c4ea3086 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -19,6 +19,9 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -29,9 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -44,16 +44,14 @@ class OperationDumper : public DfsHloVisitorWithDefault { explicit OperationDumper(const string& path) : path_(path) {} Status DefaultAction(HloInstruction* hlo) override { - string params = tensorflow::str_util::Join( + string params = absl::StrJoin( hlo->operands(), ", ", [](string* out, const HloInstruction* operand) { - tensorflow::strings::StrAppend( - out, ShapeUtil::HumanString(operand->shape())); + absl::StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); // Spit `op_name(params...) -> result_type :: path` to stdout. - std::cout << tensorflow::strings::Printf( - "%s :: (%s) -> %s :: %s\n", HloOpcodeString(hlo->opcode()).c_str(), - params.c_str(), ShapeUtil::HumanString(hlo->shape()).c_str(), - path_.c_str()); + std::cout << absl::StrFormat("%s :: (%s) -> %s :: %s\n", + HloOpcodeString(hlo->opcode()), params, + ShapeUtil::HumanString(hlo->shape()), path_); return Status::OK(); } @@ -61,7 +59,7 @@ class OperationDumper : public DfsHloVisitorWithDefault { string path_; }; -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { LocalClient* client = ClientLibrary::LocalClientOrDie(); LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); @@ -106,8 +104,8 @@ void RealMain(tensorflow::gtl::ArraySlice args) { int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off 
the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index f03e1b1f965af761c101555fd0275bc0425b9cf0..723569862c7550387e95003e3a673743464b67b8 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -34,7 +34,7 @@ limitations under the License. 
namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { +void RealMain(absl::Span args, bool compile) { LocalClient* client = ClientLibrary::LocalClientOrDie(); LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); @@ -102,8 +102,8 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(usage.c_str(), &argc, &argv); QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage; - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args, compile); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index dc5c106d02cb679f3e6f5b2bea40bbb42f8bd1cc..07ef5ff656bb48519a700a1d7d6c60b655a40ed6 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -35,7 +36,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -45,7 +45,7 @@ using tensorflow::Env; namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { HloSnapshot module; @@ -78,8 +78,8 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc index eb7bff053b1fc028fdb6930dbc496c3b6d9fae47..23ce1d235b9f2613505f8a3bfbd1a4c1162debd4 100644 --- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc +++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc @@ -17,10 +17,10 @@ limitations under the License. 
#include #include +#include "absl/base/casts.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" #include "tensorflow/core/platform/env.h" @@ -67,9 +67,8 @@ int main(int argc, char** argv) { floats.push_back(value); } - tensorflow::StringPiece content( - tensorflow::bit_cast(floats.data()), - floats.size() * sizeof(float)); + absl::string_view content(absl::bit_cast(floats.data()), + floats.size() * sizeof(float)); TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output_file, content)); return 0; diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index b4774233e588dc407bfb88defca9bf55e08eea09..ba814af4769f43dbe96190c902cf6f52ca5659bb 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -40,6 +40,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -59,7 +60,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -160,7 +160,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, // concurrent infeed occur via the fake_infeed_shape, or when // --generate_fake_infeed is passed and there exists an infeed operation in // the HloSnapshot. - tensorflow::gtl::optional pool; + absl::optional pool; std::unique_ptr data; if (provide_infeed) { data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); @@ -196,7 +196,7 @@ StatusOr ReplayComputation(const HloSnapshot& module, StreamExecutorMemoryAllocator allocator( client->platform(), {client->platform()->ExecutorForDevice(0).ValueOrDie()}); - tensorflow::gtl::optional result; + absl::optional result; for (int i = 0; i < opts.num_runs; ++i) { // If xla_hlo_profile is enabled, print a noisy message before the last run, // making it easier to separate this profile from the others in the logspam. @@ -250,10 +250,10 @@ StatusOr ParseInputFile(const string& filename, } fprintf(stderr, "%s: is not HLO text. 
Nothing left to try.\n", filename.c_str()); - return InvalidArgument("Could not parse %s.", filename.c_str()); + return InvalidArgument("Could not parse %s.", filename); } -int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { +int RealMain(absl::Span args, const Options& opts) { LocalClient* client = ClientLibrary::LocalClientOrDie(); int exit_status = EXIT_SUCCESS; @@ -344,7 +344,7 @@ int main(int argc, char** argv) { LOG(QFATAL) << usage; } - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] return xla::tools::RealMain(args, opts); } diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index 4e53fafcc97ff53afc5713e7ed8ee5222fac316b..cdf306dfd1027cf6022c5d8ae844b4308f580e8d 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -29,6 +29,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -45,7 +45,7 @@ limitations under the License. 
namespace xla { namespace tools { -void RealMain(tensorflow::gtl::ArraySlice args) { +void RealMain(absl::Span args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { HloSnapshot module; @@ -66,8 +66,8 @@ void RealMain(tensorflow::gtl::ArraySlice args) { int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + absl::Span args(argv, argc); + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index e43498e381b8e63543e2ddda08ca7c0df91817e4..68cab7387cf1576072f96878b50f07def6862d8b 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -18,12 +18,13 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stacktrace.h" @@ -54,111 +55,28 @@ ScopedLoggingTimer::~ScopedLoggingTimer() { } } -Status AddStatus(Status prior, tensorflow::StringPiece context) { +Status AddStatus(Status prior, absl::string_view context) { CHECK(!prior.ok()); - return Status{prior.code(), tensorflow::strings::StrCat( - context, ": ", prior.error_message())}; + return Status{prior.code(), + absl::StrCat(context, ": ", prior.error_message())}; } -Status AppendStatus(Status prior, tensorflow::StringPiece context) { +Status AppendStatus(Status 
prior, absl::string_view context) { CHECK(!prior.ok()); - return Status{prior.code(), tensorflow::strings::StrCat(prior.error_message(), - ": ", context)}; + return Status{prior.code(), + absl::StrCat(prior.error_message(), ": ", context)}; } -// Implementation note: we can't common these out (without using macros) because -// they all need to va_start/va_end their varargs in their frame. - -Status InvalidArgumentV(const char* format, va_list args) { - string message; - tensorflow::strings::Appendv(&message, format, args); - return WithLogBacktrace(tensorflow::errors::InvalidArgument(message)); -} - -Status InvalidArgument(const char* format, ...) { - va_list args; - va_start(args, format); - Status result = InvalidArgumentV(format, args); - va_end(args); - return result; -} - -Status Unimplemented(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Unimplemented(message)); -} - -Status InternalError(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Internal(message)); -} - -Status FailedPrecondition(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::FailedPrecondition(message)); -} - -Status Cancelled(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Cancelled(message)); -} - -Status ResourceExhausted(const char* format, ...) 
{ - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::ResourceExhausted(message)); -} - -Status NotFound(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::NotFound(message)); -} - -Status Unavailable(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Unavailable(message)); -} - -string Reindent(tensorflow::StringPiece original, - const tensorflow::StringPiece indentation) { - std::vector pieces = tensorflow::str_util::Split( - tensorflow::StringPiece(original.data(), original.size()), '\n'); - return tensorflow::str_util::Join( - pieces, "\n", [indentation](string* out, string s) { - tensorflow::StringPiece piece(s); - tensorflow::str_util::RemoveWhitespaceContext(&piece); - tensorflow::strings::StrAppend(out, indentation, piece); - }); +string Reindent(absl::string_view original, + const absl::string_view indentation) { + std::vector pieces = + absl::StrSplit(absl::string_view(original.data(), original.size()), '\n'); + return absl::StrJoin(pieces, "\n", [indentation](string* out, string s) { + absl::StrAppend(out, indentation, absl::StripAsciiWhitespace(s)); + }); } -bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank) { +bool IsPermutation(absl::Span permutation, int64 rank) { if (rank != permutation.size()) { return false; } @@ -172,7 +90,7 @@ bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank) { } std::vector InversePermutation( - tensorflow::gtl::ArraySlice input_permutation) { + absl::Span input_permutation) { DCHECK(IsPermutation(input_permutation, input_permutation.size())); std::vector 
output_permutation(input_permutation.size(), -1); for (size_t i = 0; i < input_permutation.size(); ++i) { @@ -181,8 +99,8 @@ std::vector InversePermutation( return output_permutation; } -std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, - tensorflow::gtl::ArraySlice p2) { +std::vector ComposePermutations(absl::Span p1, + absl::Span p2) { CHECK_EQ(p1.size(), p2.size()); std::vector output; for (size_t i = 0; i < p1.size(); ++i) { @@ -191,7 +109,7 @@ std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, return output; } -bool IsIdentityPermutation(tensorflow::gtl::ArraySlice permutation) { +bool IsIdentityPermutation(absl::Span permutation) { for (int64 i = 0; i < permutation.size(); ++i) { if (permutation[i] != i) { return false; @@ -212,7 +130,7 @@ PaddingConfig MakeNoPaddingConfig(int64 rank) { } PaddingConfig MakeEdgePaddingConfig( - tensorflow::gtl::ArraySlice> padding) { + absl::Span> padding) { PaddingConfig padding_config; for (const std::pair& dim : padding) { auto dimension = padding_config.add_dimensions(); @@ -234,20 +152,20 @@ bool HasInteriorPadding(const PaddingConfig& config) { namespace { string HumanReadableNumOps(double flops, double nanoseconds, - tensorflow::StringPiece op_prefix) { + absl::string_view op_prefix) { if (nanoseconds == 0) { - return tensorflow::strings::StrCat("NaN ", op_prefix, "OP/s"); + return absl::StrCat("NaN ", op_prefix, "OP/s"); } double nano_flops = flops / nanoseconds; string throughput = tensorflow::strings::HumanReadableNum( static_cast(nano_flops * 1e9)); - tensorflow::StringPiece sp(throughput); + absl::string_view sp(throughput); // Use the more common "G(FLOPS)", rather than "B(FLOPS)" - if (tensorflow::str_util::EndsWith(sp, "B") || // Ends in 'B', ignoring case - tensorflow::str_util::EndsWith(sp, "b")) { + if (absl::EndsWith(sp, "B") || // Ends in 'B', ignoring case + absl::EndsWith(sp, "b")) { *throughput.rbegin() = 'G'; } - throughput += tensorflow::strings::StrCat(op_prefix, "OP/s"); + 
throughput += absl::StrCat(op_prefix, "OP/s"); return throughput; } } // namespace @@ -260,8 +178,7 @@ string HumanReadableNumTranscendentalOps(double trops, double nanoseconds) { return HumanReadableNumOps(trops, nanoseconds, "TR"); } -void LogLines(int sev, tensorflow::StringPiece text, const char* fname, - int lineno) { +void LogLines(int sev, absl::string_view text, const char* fname, int lineno) { const int orig_sev = sev; if (sev == tensorflow::FATAL) { sev = tensorflow::ERROR; @@ -275,7 +192,7 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname, size_t cur = 0; while (cur < text.size()) { size_t eol = text.find('\n', cur); - if (eol == tensorflow::StringPiece::npos) { + if (eol == absl::string_view::npos) { eol = text.size(); } auto msg = text.substr(cur, eol - cur); @@ -290,14 +207,13 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname, } } -int64 Product(tensorflow::gtl::ArraySlice xs) { +int64 Product(absl::Span xs) { return std::accumulate(xs.begin(), xs.end(), static_cast(1), std::multiplies()); } -std::vector> CommonFactors( - tensorflow::gtl::ArraySlice a, - tensorflow::gtl::ArraySlice b) { +std::vector> CommonFactors(absl::Span a, + absl::Span b) { CHECK_EQ(Product(a), Product(b)); if (0 == Product(a)) { return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())}; diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 5ae099a4622bb7116c7a17f93060b699ead6e3a6..8ce741647414a1fa75e6d706ec1e719ace7b7cc8 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -24,17 +24,20 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -54,7 +57,7 @@ Status WithLogBacktrace(const Status& status); // the InlinedVector will just behave like an std::vector<> and allocate the // memory to store its values. static constexpr int kInlineRank = 8; -using DimensionVector = tensorflow::gtl::InlinedVector; +using DimensionVector = absl::InlinedVector; // RAII timer that logs with a given label the wall clock time duration in human // readable form. This differs from base's ElapsedTimer primarily in that it @@ -98,65 +101,63 @@ struct ScopedLoggingTimer { uint64 start_micros; }; -// Given a vector, returns a MutableArraySlice that points at its +// Given a vector, returns a Span that points at its // internals. // // Warning: if the vector is updated its storage pointer may change, so use this // with caution (ideally in limited scopes with temporary lifetimes). 
template -tensorflow::gtl::MutableArraySlice MutableByteSlice(std::vector* v) { - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(v->data()), v->size() * sizeof(T)); +absl::Span MutableByteSlice(std::vector* v) { + return absl::Span(reinterpret_cast(v->data()), + v->size() * sizeof(T)); } // Turns an immutable slice of type T into an immutable slice of bytes with the // same byte size. template -tensorflow::gtl::ArraySlice CastToByteSlice( - tensorflow::gtl::ArraySlice slice) { - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size() * sizeof(T)); +absl::Span CastToByteSlice(absl::Span slice) { + return absl::Span(reinterpret_cast(slice.data()), + slice.size() * sizeof(T)); } // Casts a byte slice to a non-byte type T, checking that the original slice // length is a multiple of sizeof(T). template -tensorflow::gtl::ArraySlice CastByteSlice( - tensorflow::gtl::ArraySlice slice) { +absl::Span CastByteSlice(absl::Span slice) { CHECK_EQ(0, slice.size() % sizeof(T)); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size() / sizeof(T)); + return absl::Span(reinterpret_cast(slice.data()), + slice.size() / sizeof(T)); } // Convenience function to force a vector to convert to an immutable slice. template -tensorflow::gtl::ArraySlice AsSlice(const std::vector& v) { - return tensorflow::gtl::ArraySlice(v); +absl::Span AsSlice(const std::vector& v) { + return absl::Span(v); } -// Converts a mutable vector pointer into a MutableArraySlice of the same +// Converts a mutable vector pointer into a Span of the same // type. template -tensorflow::gtl::MutableArraySlice AsMutableSlice(std::vector* v) { - return tensorflow::gtl::MutableArraySlice(v->data(), v->size()); +absl::Span AsMutableSlice(std::vector* v) { + return absl::Span(v->data(), v->size()); } // xla::int64 is not the same type as tensorflow::protobuf_int64 in open-source. 
// Wrapper function that gives an int64 array slice view of a repeated int64 // protobuf field. -static inline tensorflow::gtl::ArraySlice AsInt64Slice( +static inline absl::Span AsInt64Slice( const tensorflow::protobuf::RepeatedField& v) { - tensorflow::gtl::ArraySlice slice(v); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size()); + absl::Span slice(v); + return absl::Span(reinterpret_cast(slice.data()), + slice.size()); } // As above, but for uint64 types. -static inline tensorflow::gtl::ArraySlice AsUInt64Slice( +static inline absl::Span AsUInt64Slice( const tensorflow::protobuf::RepeatedField& v) { - tensorflow::gtl::ArraySlice slice(v); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(slice.data()), slice.size()); + absl::Span slice(v); + return absl::Span(reinterpret_cast(slice.data()), + slice.size()); } // Compares two containers for equality. Returns true iff the two containers @@ -172,7 +173,7 @@ template bool ContainersEqual(const Container1T& c1, std::initializer_list il) { - tensorflow::gtl::ArraySlice c2{il}; + absl::Span c2{il}; return ContainersEqual(c1, c2); } @@ -190,9 +191,9 @@ bool ContainersEqual(const Container1T& c1, const Container2T& c2, // source and destination. The source starting index is src_base, while the // destination one is dest_base. template -void StridedCopy(tensorflow::gtl::MutableArraySlice dest, int64 dest_base, - int64 dest_stride, tensorflow::gtl::ArraySlice src, - int64 src_base, int64 src_stride, int64 count) { +void StridedCopy(absl::Span dest, int64 dest_base, int64 dest_stride, + absl::Span src, int64 src_base, int64 src_stride, + int64 count) { for (; count > 0; --count, dest_base += dest_stride, src_base += src_stride) { dest[dest_base] = static_cast(src[src_base]); } @@ -201,46 +202,76 @@ void StridedCopy(tensorflow::gtl::MutableArraySlice dest, int64 dest_base, // Adds some context information to the error message in a // Status. 
This is useful as Statuses are // propagated upwards. -Status AddStatus(Status prior, tensorflow::StringPiece context); -Status AppendStatus(Status prior, tensorflow::StringPiece context); - -// Status error shorthands -- printfs the arguments to be -// used as an error message and returns a status in the canonical -// error space. -Status InvalidArgument(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status Unimplemented(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status InternalError(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status FailedPrecondition(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status Cancelled(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status ResourceExhausted(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status NotFound(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); - -// Passed-varargs variant of the InvalidArgument factory above. -Status InvalidArgumentV(const char* format, va_list args); +Status AddStatus(Status prior, absl::string_view context); +Status AppendStatus(Status prior, absl::string_view context); + +// Status error shorthands -- StrFormat's the arguments to be used as an error +// message and returns a status in the canonical error space. +template +Status InvalidArgument(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::InvalidArgument(absl::StrFormat(format, args...))); +} +template +Status Unimplemented(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Unimplemented(absl::StrFormat(format, args...))); +} +template +Status InternalError(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Internal(absl::StrFormat(format, args...))); +} +template +Status FailedPrecondition(const absl::FormatSpec& format, + const Args&... 
args) { + return WithLogBacktrace( + tensorflow::errors::FailedPrecondition(absl::StrFormat(format, args...))); +} +template +Status Cancelled(const absl::FormatSpec& format, const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Cancelled(absl::StrFormat(format, args...))); +} +template +Status ResourceExhausted(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::ResourceExhausted(absl::StrFormat(format, args...))); +} +template +Status NotFound(const absl::FormatSpec& format, const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::NotFound(absl::StrFormat(format, args...))); +} +template +Status Unavailable(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Unavailable(absl::StrFormat(format, args...))); +} template Status InvalidArgumentStrCat(Args&&... concat) { - return InvalidArgument( - "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); + return InvalidArgument("%s", absl::StrCat(std::forward(concat)...)); } template Status UnimplementedStrCat(Args&&... concat) { - return Unimplemented( - "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); + return Unimplemented("%s", absl::StrCat(std::forward(concat)...)); } template Status InternalErrorStrCat(Args&&... concat) { - return InternalError( - "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); + return InternalError("%s", absl::StrCat(std::forward(concat)...)); } template Status ResourceExhaustedStrCat(Args&&... concat) { - return ResourceExhausted( - "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); + return ResourceExhausted("%s", absl::StrCat(std::forward(concat)...)); } // Splits the lines of the original, replaces leading whitespace with the prefix @@ -249,11 +280,10 @@ Status ResourceExhaustedStrCat(Args&&... 
concat) { // // Note: even different amounts of leading whitespace on different lines will be // uniformly replaced with "indentation". -string Reindent(tensorflow::StringPiece original, - tensorflow::StringPiece indentation); +string Reindent(absl::string_view original, absl::string_view indentation); // Checks whether permutation is a permutation of the [0, rank) integer range. -bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); +bool IsPermutation(absl::Span permutation, int64 rank); // Applies `permutation` on `input` and returns the permuted array. // For each i, output[permutation[i]] = input[i]. @@ -261,10 +291,11 @@ bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); // Precondition: // 1. `permutation` is a permutation of 0..permutation.size()-1. // 2. permutation.size() == input.size(). -template