diff --git a/RELEASE.md b/RELEASE.md index 763ef3b279dde209ed387534032deae40a33a9e4..bdc23795e55800a885386ab8d63b032fa4979149 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,9 @@ +# Release 1.10.1 +## Bug Fixes and Other Changes + +* `tf.keras`: + * Fixing keras on Cloud TPUs. No new binaries will be built for Windows. + # Release 1.10.0 ## Major Features And Improvements diff --git a/configure.py b/configure.py index 361bd4764dc5c1900be7378f51c00aedf6f2ce41..52a513779e601482d673297ed08e43133c5ad3c7 100644 --- a/configure.py +++ b/configure.py @@ -852,7 +852,7 @@ def set_tf_cuda_version(environ_cp): # Reset and retry print('Invalid path to CUDA %s toolkit. %s cannot be found' % - (tf_cuda_version, cuda_toolkit_path_full)) + (tf_cuda_version, cuda_toolkit_paths_full)) environ_cp['TF_CUDA_VERSION'] = '' environ_cp['CUDA_TOOLKIT_PATH'] = '' diff --git a/tensorflow/BUILD b/tensorflow/BUILD index b5e0a4e98b0c183454afa4a4389dcf73802b219b..386e0096ff705c2eaa98f42833ef650bac6fc8d8 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -12,6 +12,7 @@ exports_files([ # The leakr files are used by //third_party/cloud_tpu. "leakr_badwords.dic", "leakr_badfiles.dic", + "leakr_file_type_recipe.ftrcp", ]) load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") @@ -23,6 +24,11 @@ load( "//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files", # @unused ) +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "get_compat_files") +load( + "//tensorflow/python/tools/api/generator:api_init_files.bzl", + "TENSORFLOW_API_INIT_FILES", # @unused +) load( "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", "TENSORFLOW_API_INIT_FILES_V1", # @unused @@ -32,6 +38,11 @@ load( "if_ngraph", ) +# @unused +TENSORFLOW_API_INIT_FILES_V2 = ( + TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +) + # Config setting used when building for products # which requires restricted licenses to be avoided. 
config_setting( @@ -427,12 +438,20 @@ config_setting( visibility = ["//visibility:public"], ) +# This flag specifies whether TensorFlow 2.0 API should be built instead +# of 1.* API. Note that TensorFlow 2.0 API is currently under development. +config_setting( + name = "api_version_2", + define_values = {"tf_api_version": "2"}, +) + package_group( name = "internal", packages = [ "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", + "//tensorflow_estimator/...", "//tensorflow_fold/llgtm/...", "//third_party/py/tensor2tensor/...", ], @@ -590,13 +609,39 @@ exports_files( ) gen_api_init_files( - name = "tensorflow_python_api_gen", + name = "tf_python_api_gen_v1", srcs = ["api_template.__init__.py"], api_version = 1, + output_dir = "_api/v1/", output_files = TENSORFLOW_API_INIT_FILES_V1, + output_package = "tensorflow._api.v1", + root_init_template = "api_template.__init__.py", +) + +gen_api_init_files( + name = "tf_python_api_gen_v2", + srcs = ["api_template.__init__.py"], + api_version = 2, + compat_api_versions = [1], + output_dir = "_api/v2/", + output_files = TENSORFLOW_API_INIT_FILES_V2, + output_package = "tensorflow._api.v2", root_init_template = "api_template.__init__.py", ) +genrule( + name = "root_init_gen", + srcs = select({ + "api_version_2": [":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }), + outs = ["__init__.py"], + cmd = select({ + "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)", + "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)", + }), +) + py_library( name = "tensorflow_py", srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"], @@ -611,7 +656,10 @@ py_library( py_library( name = "tensorflow_py_no_contrib", - srcs = [":tensorflow_python_api_gen"], + srcs = select({ + "api_version_2": [":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }) + [":root_init_gen"], srcs_version = "PY2AND3", visibility = 
["//visibility:public"], deps = ["//tensorflow/python:no_contrib"], diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 779f65d5b17c350833f67f07985b00e8eb561e72..53a72b84430ac703323e8235b4e3393d1c9898bc 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -18,11 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os as _os + # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import try: - import os # pylint: disable=g-import-not-at-top # Add `estimator` attribute to allow access to estimator APIs via # "tf.estimator..." from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top @@ -30,9 +31,8 @@ try: # Add `estimator` to the __path__ to allow "from tensorflow.estimator..." # style imports. from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top - __path__ += [os.path.dirname(estimator_api.__file__)] + __path__ += [_os.path.dirname(estimator_api.__file__)] del estimator_api - del os except (ImportError, AttributeError): print('tf.estimator package not installed.') @@ -45,6 +45,12 @@ del LazyLoader from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top app.flags = flags # pylint: disable=undefined-variable +# Make sure directory containing top level submodules is in +# the __path__ so that "from tensorflow.foo import bar" works. +_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable +if _tf_api_dir not in __path__: + __path__.append(_tf_api_dir) + del absolute_import del division del print_function @@ -54,6 +60,12 @@ del print_function # must come from this module. So python adds these symbols for the # resolution to succeed. 
# pylint: disable=undefined-variable -del python -del core +try: + del python + del core +except NameError: + # Don't fail if these modules are not available. + # For example, we are using this file for compat.v1 module as well and + # 'python', 'core' directories are not under compat/v1. + pass # pylint: enable=undefined-variable diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 109b3b37aace34914e5307981ead597c25c7fb8f..43c279bd800d79eeaf9a25bbc1978148f93c0a50 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -204,6 +204,7 @@ tf_cuda_cc_test( "//tensorflow:darwin": ["-headerpad_max_install_names"], "//conditions:default": [], }), + tags = ["noasan"], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 173bbea596a4276559f5cd67824e5cc75313985c..79811ceae57e0bddeb2a6f32bad7003e14e23422 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 69b3ffe2a1f620e346405607ecf742fb863aa644..c195c9e01ca920c7234499b6e1d5e9cbf24056f3 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -79,6 +80,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation, auto* gpu_options = config.mutable_gpu_options(); gpu_options->set_allow_growth(gpu_memory_allow_growth); + // TODO(b/113217601): This is needed for EagerContext::runner_ to use a + // threadpool, so that we avoid the possibility of running the runner_ in the + // threadpool of GPU event mgr, as that can trigger more callbacks to be + // scheduled on that same threadpool, causing a deadlock in cases where the + // caller of event_mgr->ThenExecute() blocks on the completion of the callback + // (as in the case of ConstOp kernel creation on GPU, which involves copying a + // CPU tensor to GPU). + // Setting a larger thread pool does not help with the Swift caller, as we use + // a different TFE context for each thread of execution (for running graph + // functions, and their send/recvs coroutines). + config.set_inter_op_parallelism_threads(1); + TF_Buffer* ret = TF_NewBuffer(); TF_CHECK_OK(MessageToBuffer(config, ret)); return ret; @@ -8494,3 +8507,201 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, /*run_metadata*/ nullptr, status); VLOG(1) << "Enqueuing is done."; } + +TFE_Context* TFE_CreateContextFromSession(TF_Session* session, + TF_Status* status) { + auto* opts = TFE_NewContextOptions(); + + // Reduce GPU memory allocation, and set appropriate config options for TFE + // context. 
+ auto* config = + TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true); + TFE_ContextOptionsSetConfig(opts, config->data, config->length, status); + if (!status->status.ok()) { + CHECK(!config); + TFE_DeleteContextOptions(opts); + return nullptr; + } + + auto* ctx = TFE_NewContextFromSession(opts, session, status); + TF_DeleteBuffer(config); + TFE_DeleteContextOptions(opts); + return ctx; +} + +// TODO: retrieve the device string via TFE_ContextListDevices() +static const char DEFAULT_CPU_DEVICE[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + +static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType, + int tensor_id, TF_Status* status) { + std::unique_ptr queueOp( + TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp); + TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return nullptr; + // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler. + TFE_OpSetAttrInt(queueOp.get(), "capacity", 1); + TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1); + auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id); + TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(), + shared_name.size()); + TFE_OpSetAttrString(queueOp.get(), "container", "", 0); + + // TODO: consider making this an unknown shape. 
+ const int64_t* dims_ptr = nullptr; + int num_dims = 0; + TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims, + /*num_values*/ 0, status); + if (!status->status.ok()) return nullptr; + + int num_retvals = 1; + TFE_TensorHandle* queue = nullptr; + TFE_Execute(queueOp.get(), &queue, &num_retvals, status); + if (!status->status.ok()) return nullptr; + CHECK_EQ(num_retvals, 1); + + return queue; +} + +static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType, + TFE_TensorHandle* queue, TFE_TensorHandle* tensor, + TF_Status* status) { + TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status); + if (!status->status.ok()) return; + std::unique_ptr op_deleter(op, TFE_DeleteOp); + TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return; + TFE_OpAddInput(op, queue, status); + if (!status->status.ok()) return; + TFE_OpAddInput(op, tensor, status); + if (!status->status.ok()) return; + TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1); + TFE_OpSetAttrInt(op, "timeout_ms", -1); + + int num_retvals = 0; + TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status); + if (!status->status.ok()) return; + CHECK_EQ(num_retvals, 0); +} + +static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx, + TF_DataType inputType, + TFE_TensorHandle* queue, + TF_Status* status) { + TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status); + if (!status->status.ok()) return nullptr; + std::unique_ptr op_deleter(op, TFE_DeleteOp); + TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status); + if (!status->status.ok()) return nullptr; + + TFE_OpAddInput(op, queue, status); + if (!status->status.ok()) return nullptr; + TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1); + TFE_OpSetAttrInt(op, "timeout_ms", -1); + TFE_TensorHandle* ret; + int num_retvals = 1; + TFE_Execute(op, &ret, &num_retvals, status); + if (!status->status.ok()) return nullptr; + CHECK_EQ(num_retvals, 1); + return ret; +} + +TFE_TensorHandle* 
TFE_DequeueNamedTensor(TF_Session* session, int tensor_id, + TF_DataType inputType, + TF_Status* status) { + assert(session); + VLOG(1) << "Dequeuing data tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + auto* ret = createTFEDequeue(ctx, inputType, queue, status); + return ret; +} + +TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, + TF_DataType inputType, + TF_Status* status) { + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + auto* ret = createTFEDequeue(ctx, inputType, queue, status); + + return ret; +} + +void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id, + TFE_TensorHandle* tensor, TF_Status* status) { + assert(session); + VLOG(1) << "Enqueuing data tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return; + std::unique_ptr ctx_deleter( + ctx, TFE_DeleteContext); + + TF_DataType inputType = TFE_TensorHandleDataType(tensor); + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, inputType, queue, tensor, status); +} + +void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status) { + VLOG(1) << "Enqueuing data tensor with id " << tensor_id; + + TF_DataType inputType = TFE_TensorHandleDataType(tensor); + TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, 
status); + if (!status->status.ok()) return; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, inputType, queue, tensor, status); +} + +void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id, + TFE_TensorHandle* tensor, TF_Status* status) { + VLOG(1) << "Enqueuing variant tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return; + std::unique_ptr ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); + if (!status->status.ok()) return; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status); +} + +TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id, + TF_Status* status) { + VLOG(1) << "Dequeuing variant tensor with id " << tensor_id; + + auto ctx = TFE_CreateContextFromSession(session, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr ctx_deleter( + ctx, TFE_DeleteContext); + + TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status); + if (!status->status.ok()) return nullptr; + std::unique_ptr + queue_deleter(queue, TFE_DeleteTensorHandle); + + return createTFEDequeue(ctx, TF_VARIANT, queue, status); +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 09d482d6df45aa95a2f463f1c9601048bea24c04..522c91f67efdf10118268842dee3beb334fb720d 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -132,9 +132,48 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, TF_Tensor* tensor, TF_Status* status); +// TODO: remove this API in favor of the next one. 
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession( const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status); +// Creates from `session` a new eager context to run a graph function or +// sends/recvs, so that these concurrent TFE executions can share (via +// `session` and its associated device mgr) the same set of fifo queue resource +// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and +// graph function execution can access the same fifo queue resource handles +// (associated with devices managed by the device manager, which can be obtained +// from `session`). +// +// TODO: Remove this function once we migrate away from using session. +TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession( + TF_Session* session, TF_Status* status); + +// TODO: Retire this API in favor of the next one. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor( + TF_Session* session, int tensor_id, TF_DataType inputType, + TF_Status* status); + +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx( + TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session, + int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx( + TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor, + TF_Status* status); + +// TODO: consider folding the 2 APIs below into the ones above. 
+TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session, + int tensor_id, + TFE_TensorHandle* tensor, + TF_Status* status); + +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor( + TF_Session* session, int tensor_id, TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index a2c5a42c11361779de61b515e0f08dcc45e609b9..f68f8a3e90a971b5e4a024feaf26ba498afc48da 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/strings/base64.h" diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 77e3878a94eddfa1dfd53844916f453d70bcac4a..349d9bcd7ca3991c7c3621f347af6025778612b7 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -399,6 +399,19 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { : d->name().c_str(); } +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } + + h->handle->Ref(); + + return new TFE_TensorHandle(h->handle); +} + TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { status->status = tensorflow::errors::InvalidArgument( diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 
eec2750d6eb3bceed8da3ed44812ac2e8fd5c877..337447eec9581b01fa775affc49097986824a360 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -171,6 +171,12 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); +// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor +// with `h`. On success, `status` is set to OK. On failure, `status` reflects +// the error and a nullptr is returned. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status); + // This function will block till the operation that produces `h` has // completed. The memory returned might alias the internal memory used by // TensorFlow. Hence, callers should not mutate this memory (for example by diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 7126227cf529023eadf38984668a40118641bb1b..55331022b9dbd0696928fa44430f340f371432ac 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1528,4 +1528,29 @@ TEST(CAPI, StringAttributes) { TFE_DeleteContext(ctx); TF_DeleteStatus(status); } + +TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { + TFE_TensorHandle* h = TestMatrixTensorHandle(); + EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); + + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + TFE_TensorHandle* h_shares_tensor = + TFE_TensorHandleCopySharingTensor(h, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get()); + ASSERT_EQ(16, TF_TensorByteSize(t)); + float data[4] = {0}; + memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t)); + EXPECT_EQ(1.0, data[0]); + EXPECT_EQ(2.0, data[1]); + EXPECT_EQ(3.0, data[2]); + EXPECT_EQ(4.0, data[3]); + TF_DeleteTensor(t); + + 
TFE_DeleteTensorHandle(h); + TFE_DeleteTensorHandle(h_shares_tensor); +} } // namespace diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index a085e1d6e2de5ad63d11eb8979ae64c26b91366f..0717e7dd4b358d6c212070374bcc3fd2f91ed0ab 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -150,7 +150,7 @@ class Input { Initializer(const std::initializer_list& v, const TensorShape& shape) { typedef typename RealType::type RealT; Tensor t(DataTypeToEnum::v(), shape); - if (t.NumElements() != v.size()) { + if (t.NumElements() != static_cast(v.size())) { status = errors::InvalidArgument( "Cannot construct a tensor with ", t.NumElements(), " from an initializer list with ", v.size(), " elements"); diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 2b1ce34b3770a47e31d4f623b1b4f4650206737e..b17bc658fa06b9feb7edb292bd89ef31e6309169 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/types/span.h" @@ -31,7 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace tfcompile { @@ -135,12 +135,12 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, indices = "[0]"; } else { for (int dim = 0; dim < shape.dimensions_size(); ++dim) { - dim_vars.push_back(strings::StrCat("size_t dim", dim)); - dim_sizes += strings::StrCat("[", shape.dimensions(dim), "]"); - indices += strings::StrCat("[dim", dim, "]"); + dim_vars.push_back(absl::StrCat("size_t dim", dim)); + dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]"); + indices += absl::StrCat("[dim", dim, "]"); } } - rewrites->push_back({"{{I}}", strings::StrCat(i)}); + rewrites->push_back({"{{I}}", absl::StrCat(i)}); rewrites->push_back({"{{TYPE}}", type}); rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); @@ -194,7 +194,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, arg_data({{I}}))){{INDICES}}; } )"; - *methods += RewriteWithName(strings::StrCat(i), code, rewrites); + *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.feed(i).name().empty()) { *methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites); } @@ -235,7 +235,7 @@ Status GenResultMethods(const tf2xla::Config& config, result_data({{I}}))){{INDICES}}; } )"; - *methods += RewriteWithName(strings::StrCat(i), code, rewrites); + *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.fetch(i).name().empty()) { *methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites); } @@ -304,8 +304,8 @@ std::vector BufferInfosToCppExpression( string encoded_second_as_str = encoded.second == ~0ULL ? 
"~0ULL" - : strings::StrCat(encoded.second, "ULL"); - return strings::StrCat( + : absl::StrCat(encoded.second, "ULL"); + return absl::StrCat( "::tensorflow::cpu_function_runtime::BufferInfo({", encoded.first, "ULL, ", encoded_second_as_str, "})"); }); @@ -352,13 +352,13 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, // Create rewrite strings for namespace start and end. string ns_start; for (const string& n : opts.namespaces) { - ns_start += strings::StrCat("namespace ", n, " {\n"); + ns_start += absl::StrCat("namespace ", n, " {\n"); } ns_start += "\n"; string ns_end("\n"); for (int i = opts.namespaces.size() - 1; i >= 0; --i) { const string& n = opts.namespaces[i]; - ns_end += strings::StrCat("} // end namespace ", n, "\n"); + ns_end += absl::StrCat("} // end namespace ", n, "\n"); } // Generate metadata. @@ -568,10 +568,10 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { )"; // The replacement strategy is naive, but good enough for our purposes. 
const std::vector> rewrites = { - {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)}, - {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, + {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)}, + {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, - {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())}, + {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())}, {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, @@ -590,11 +590,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", metadata_result.program_shape_access_shim}, - {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, + {"{{RESULT_INDEX}}", absl::StrCat(result_index)}, {"{{RESULT_NAMES_CODE}}", result_names_code}, - {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, - {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, - {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())}, + {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)}, + {"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)}, + {"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())}, {"{{BUFFER_INFOS_AS_STRING}}", absl::StrJoin(buffer_infos_as_strings, ",\n")}}; absl::StrReplaceAll(rewrites, header); @@ -602,13 +602,13 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { } static string CreateUniqueIdentifier(const CodegenOpts& opts, - StringPiece suffix) { + absl::string_view suffix) { string result = "__tfcompile"; for (const string& n : opts.namespaces) { - strings::StrAppend(&result, "_", n); + absl::StrAppend(&result, "_", n); } - strings::StrAppend(&result, "_", opts.class_name, "_", suffix); + absl::StrAppend(&result, "_", opts.class_name, "_", suffix); return result; } @@ -678,7 +678,7 @@ Status 
ParseCppClass(const string& cpp_class, string* class_name, return Status::OK(); } -Status ValidateCppIdent(StringPiece ident, StringPiece msg) { +Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) { if (ident.empty()) { return errors::InvalidArgument("empty identifier: ", msg); } diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 83f2d3ee11d09d66f16d7ecdc11945ebe994a82a..90410c46a8e36e44454f1219ad76d0fb0937070d 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { namespace tfcompile { @@ -96,7 +96,7 @@ Status ParseCppClass(const string& cpp_class, string* class_name, // ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is // appended to error messages. -Status ValidateCppIdent(StringPiece ident, StringPiece msg); +Status ValidateCppIdent(absl::string_view ident, absl::string_view msg); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index e3a53edb7368c209bea16a9e34b1f452a8ff4bf8..bb288d23000527be74f01630d20bbf82e50007ce 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -19,11 +19,11 @@ limitations under the License. 
#include #include "absl/strings/match.h" +#include "absl/strings/string_view.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index f1e8e5c08482e15d989c19a43aa7c5f437cd091d..3c32d533f63f202fc9409f36709e0d29d1d7e002 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -38,11 +38,11 @@ using xla::llvm_ir::AsStringRef; static void AddEmbeddedProtocolBufferToLlvmModule( llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto, - StringPiece unique_identifier, string* protobuf_array_symbol_name, + absl::string_view unique_identifier, string* protobuf_array_symbol_name, int64* protobuf_array_size) { string protobuf_array_contents = proto.SerializeAsString(); *protobuf_array_symbol_name = - strings::StrCat(unique_identifier, "_protobuf_array_contents"); + absl::StrCat(unique_identifier, "_protobuf_array_contents"); *protobuf_array_size = protobuf_array_contents.size(); llvm::Constant* protobuf_array_initializer = @@ -55,9 +55,9 @@ static void AddEmbeddedProtocolBufferToLlvmModule( protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name)); } -static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, - StringPiece protobuf_array_symbol_name, - int64 protobuf_array_size) { +static string CreateCPPShimExpression( + absl::string_view qualified_cpp_protobuf_name, + absl::string_view protobuf_array_symbol_name, int64 protobuf_array_size) { string code = "[]() {\n" " {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n" @@ -68,9 
+68,9 @@ static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, return absl::StrReplaceAll( code, { - {"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)}, - {"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)}, - {"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)}, + {"{{ARRAY_SYMBOL}}", absl::StrCat(protobuf_array_symbol_name)}, + {"{{ARRAY_SIZE}}", absl::StrCat(protobuf_array_size)}, + {"{{PROTOBUF_NAME}}", absl::StrCat(qualified_cpp_protobuf_name)}, }); } @@ -93,7 +93,7 @@ static StatusOr CodegenModule(llvm::TargetMachine* target_machine, } static StatusOr> -GetTargetMachineFromTriple(StringPiece target_triple) { +GetTargetMachineFromTriple(absl::string_view target_triple) { std::string error; std::string normalized_triple = llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple))); @@ -110,7 +110,7 @@ GetTargetMachineFromTriple(StringPiece target_triple) { } StatusOr CreateEmbeddedProtocolBuffers( - StringPiece target_triple, + absl::string_view target_triple, absl::Span protobufs_to_embed) { TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, GetTargetMachineFromTriple(target_triple)); @@ -135,8 +135,8 @@ StatusOr CreateEmbeddedProtocolBuffers( protobuf_to_embed.qualified_cpp_protobuf_name, protobuf_array_symbol_name, protobuf_array_size); - cpp_variable_decl = strings::StrCat("extern \"C\" char ", - protobuf_array_symbol_name, "[];"); + cpp_variable_decl = + absl::StrCat("extern \"C\" char ", protobuf_array_symbol_name, "[];"); } else { cpp_shim = "nullptr"; } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index 4f940c019750f49da4ad2386aa4b23281cc5a9fc..cf5c04ac4bdff73b76a365c346f7db60ce2d8197 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -83,7 +83,7 @@ struct ProtobufToEmbed { // is stored in the object_file_data field in the returned // 
EmbeddedProtocolBuffers instance. StatusOr CreateEmbeddedProtocolBuffers( - StringPiece target_triple, + absl::string_view target_triple, absl::Span protobufs_to_embed); } // namespace tfcompile diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 723e9bec8afcfbf7ceeeb59c63e4e12442fdb7ab..7a0932d44d405de0f2edf072f4760126bff36719 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -67,7 +67,12 @@ genrule( "test_graph_tfmatmulandadd.pb", "test_graph_tfsplits.pb", ], - cmd = "$(location :make_test_graphs) --out_dir $(@D)", + # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any + # GPUs which might be present. This is important because builds may run + # concurrently with tests, and tests need to be able to assume that they + # have control of the full GPU. + cmd = "CUDA_VISIBLE_DEVICES='' " + + "$(location :make_test_graphs) --out_dir $(@D)", tags = ["manual"], tools = [":make_test_graphs"], ) @@ -226,6 +231,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_profile_printer", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index dd2b151098f2054571ac32b8b506cbc00659588a..7ac90fb8a9c73bdbc149f263d7d229a6514769f8 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -543,7 +544,13 @@ TEST(TFCompileTest, HloProfiling) { string hlo_profile_as_string = xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(), /*clock_rate_ghz=*/1.0); - VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string; + VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string; + + // Strip away identifier details from the profile string to avoid this test + // being a change detector for xla internals. Identifiers such as '%dot.0.7' + // just become '%dot'. + RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1"); + VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string; std::vector hlo_profile_lines = absl::StrSplit(hlo_profile_as_string, '\n'); @@ -551,16 +558,14 @@ TEST(TFCompileTest, HloProfiling) { auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); auto dot_profile_line = HasSubstr( - "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " - "%arg1.0.1)"); + "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)"); auto add_profile_line = HasSubstr( - "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " - "%arg1.0.1)"); + "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)"); auto tuple_profile_line = HasSubstr( - "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " - "%dot.0.4, f32[2,2]{1,0} %add.0.6)"); - auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); - auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); + "%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, " + "f32[2,2]{1,0} %add)"); + auto arg0_profile_line = HasSubstr("%arg0 = 
f32[2,2]{1,0} parameter(0)"); + auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)"); EXPECT_THAT(hlo_profile_lines, IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 326f73b975aec3a7a6bc7cdc9a92f540ad545ad6..792b7fe14abf91626a0aeb75cdbe319b123ec10c 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -105,12 +105,18 @@ def tf_library( freeze_file = freeze_name + ".pb" # First run tfcompile to generate the list of out_nodes. + # + # Here and below, we set CUDA_VISIBLE_DEVICES='' to prevent the code we + # launch from using any GPUs which might be present. This is important + # because builds may run concurrently with tests, and tests need to be + # able to assume that they have control of the full GPU. out_nodes_file = "out_nodes_" + freeze_name native.genrule( name = ("gen_" + out_nodes_file), srcs = [config], outs = [out_nodes_file], - cmd = ("$(location " + tfcompile_tool + ")" + + cmd = ("CUDA_VISIBLE_DEVICES='' " + + "$(location " + tfcompile_tool + ")" + " --config=$(location " + config + ")" + " --dump_fetch_nodes > $@"), tools = [tfcompile_tool], @@ -142,9 +148,12 @@ def tf_library( out_nodes_file, ] + freeze_saver_srcs, outs = [freeze_file], - cmd = ("$(location " + - "//tensorflow/python/tools:freeze_graph)" + - freeze_args), + cmd = ( + "CUDA_VISIBLE_DEVICES='' " + + "$(location " + + "//tensorflow/python/tools:freeze_graph)" + + freeze_args + ), tools = ["//tensorflow/python/tools:freeze_graph"], tags = tags, ) @@ -177,16 +186,19 @@ def tf_library( metadata_object_file, function_object_file, ], - cmd = ("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_header=$(@D)/" + header_file + 
- " --out_metadata_object=$(@D)/" + metadata_object_file + - " --out_function_object=$(@D)/" + function_object_file + - " " + flags + " " + profiling_flag), + cmd = ( + "CUDA_VISIBLE_DEVICES='' " + + "$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_header=$(@D)/" + header_file + + " --out_metadata_object=$(@D)/" + metadata_object_file + + " --out_function_object=$(@D)/" + function_object_file + + " " + flags + " " + profiling_flag + ), tools = [tfcompile_tool], visibility = visibility, testonly = testonly, @@ -216,14 +228,17 @@ def tf_library( outs = [ session_module_pb, ], - cmd = ("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_session_module=$(@D)/" + session_module_pb + - " " + flags), + cmd = ( + "CUDA_VISIBLE_DEVICES='' " + + "$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_session_module=$(@D)/" + session_module_pb + + " " + flags + ), tools = [tfcompile_tool], visibility = visibility, testonly = testonly, diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index f3c44e9dda8ce96a268420a7f4d0f22e50ddfe41..b95b063348c5cdfdcaed635ba527e9f0bfd6092d 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "absl/strings/match.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" @@ -92,8 +92,9 @@ Status Main(const MainFlags& flags) { // Write output files. Env* env = Env::Default(); const std::vector& obj = compile_result.aot->object_file_data(); - TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object, - StringPiece(obj.data(), obj.size()))); + TF_RETURN_IF_ERROR( + WriteStringToFile(env, flags.out_function_object, + absl::string_view(obj.data(), obj.size()))); CodegenOpts codegen_opts; codegen_opts.gen_name_to_index = flags.gen_name_to_index; codegen_opts.gen_program_shape = flags.gen_program_shape; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index df81f3c23e38a2ec2cea827cd0adb123855e7714..7d5db713f678b696131ff4074d54e3776f019e02 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -265,6 +265,7 @@ cc_library( srcs = ["jit_compilation_pass_registration.cc"], deps = [ ":compilation_passes", + "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu_internal", ], alwayslink = 1, @@ -395,6 +396,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", ], ) @@ -410,6 +412,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", 
"//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -479,6 +482,7 @@ tf_cc_test( ":common", ":compilation_passes", ":xla_cluster_util", + ":xla_gpu_device", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", @@ -495,6 +499,8 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/grappler/optimizers/data:graph_utils", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) @@ -566,6 +572,7 @@ cc_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 82aa03810bc0ecee8ae92ed6f286867eea893287..9128b48da3fe9dd3d85d146e16c153c1b3bebf4c 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -154,7 +154,7 @@ class AndPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); + return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); } Kind kind() const override { return Kind::kAnd; } @@ -185,7 +185,7 @@ class OrPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); + return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); } Kind kind() const override { return Kind::kOr; } @@ -206,7 +206,7 @@ class NotPredicate : public Predicate { operands_({operand}) {} string ToString() const override { - return strings::StrCat("~", operand()->ToString()); + return 
absl::StrCat("~", operand()->ToString()); } Kind kind() const override { return Kind::kNot; } @@ -240,8 +240,8 @@ class AndRecurrencePredicate : public Predicate { Predicate* step() const { return operands_[1]; } string ToString() const override { - return strings::StrCat("{", start()->ToString(), ",&,", step()->ToString(), - "}"); + return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(), + "}"); } Kind kind() const override { return Kind::kAndRecurrence; } @@ -267,7 +267,7 @@ class SymbolPredicate : public Predicate { must_be_true_(must_be_true) {} string ToString() const override { - return must_be_true() ? strings::StrCat("*", tensor_id_.ToString()) + return must_be_true() ? absl::StrCat("*", tensor_id_.ToString()) : tensor_id_.ToString(); } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 2788102620546d8eab657c519f078c5b03e265cc..ae7a22f4516fc6c87c0c555214eacac71f2ea0d7 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" @@ -45,7 +46,6 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" @@ -755,7 +755,7 @@ Status Encapsulator::Subgraph::RecordArg( if (inserted) { NodeDef arg_def; NodeDefBuilder builder( - strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); @@ -790,7 +790,7 @@ Status Encapsulator::Subgraph::RecordResult( if (inserted) { NodeDef ret_def; NodeDefBuilder builder( - strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); DataType dtype = src_node->output_type(src_slot); builder.Attr("T", dtype); builder.Attr("index", ret_index); @@ -950,16 +950,15 @@ Status Encapsulator::Subgraph::AddHostComputes( } NodeDef host_compute_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", - oc_subgraph_name, "_host_compute"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", + oc_subgraph_name, "_host_compute"), kHostComputeOp); builder.Input(inputs); builder.Attr("Tinputs", input_dtypes); builder.Attr("Toutputs", output_dtypes); builder.Attr("ancestors", host_compute_ancestors); - builder.Attr("key", - strings::StrCat("host_compute_channel_", subgraph_name, "_", - oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); builder.Attr("_outside_compilation_subgraph", oc_subgraph_name); Status s = builder.Finalize(&host_compute_def); if (!s.ok()) return s; @@ -1017,8 +1016,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& 
subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; - NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), - "NoOp"); + NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp"); builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); builder.Device(device_); Status s = builder.Finalize(&seq_def); @@ -1091,10 +1089,10 @@ Status Encapsulator::Subgraph::BuildFunctionDef( if (VLOG_IS_ON(1)) { VLOG(2) << "Build function def " << name; - dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library); - dump_graph::DumpFunctionDefToFile( - strings::StrCat("encapsulate_fdef_", name), fdef); + dump_graph::DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), + *graph_, library); + dump_graph::DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), + fdef); } if (!reuse_existing_functions || library->Find(name) == nullptr) { @@ -1130,8 +1128,8 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo( host_compute->AddAttr("shapes", shapes); } else { string inference_graph_name = - strings::StrCat("_outside_compilation_shape_inference_", subgraph_name, - "_", outside_compilation_subgraph_name); + absl::StrCat("_outside_compilation_shape_inference_", subgraph_name, + "_", outside_compilation_subgraph_name); FunctionDef fdef; TF_RETURN_IF_ERROR( GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); @@ -1155,10 +1153,10 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( if (VLOG_IS_ON(1)) { VLOG(2) << "Replace function def " << name; dump_graph::DumpGraphToFile( - strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, + absl::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, library); dump_graph::DumpFunctionDefToFile( - strings::StrCat("replace_encapsulate_fdef_", name), fdef); + absl::StrCat("replace_encapsulate_fdef_", name), fdef); } TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); @@ -1186,8 +1184,7 @@ Status 
Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); NodeDef key_def; NodeDefBuilder builder( - strings::StrCat(call_node_def_.name(), "_key_placeholder"), - "Placeholder"); + absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder"); builder.Attr("dtype", DT_STRING); builder.Attr("shape", shape_proto); builder.Attr("_host_compute_call_node", call_node_def_.name()); @@ -1221,16 +1218,16 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( } NodeDef recv_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, - "_", oc_subgraph_name, "_recv"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); builder.Device(device_); builder.Attr("Toutputs", dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. builder.Attr("device_ordinal", 0); - builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, - "_", oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); builder.Attr(group_attribute, subgraph_name); builder.Attr(outside_compilation_attribute, oc_subgraph_name); builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); @@ -1276,13 +1273,13 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } NodeDef send_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, - "_", oc_subgraph_name, "_send"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_send"), kSendFromHostOp); builder.Device(device_); builder.Attr("Tinputs", dtypes); - builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, - "_", oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); // The correct device_ordinal will 
be inserted during replication in a // subsequent rewrite. builder.Attr("device_ordinal", 0); @@ -1516,7 +1513,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { // Dump subgraphs. for (auto& entry : subgraphs_) { dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first), + absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first), *entry.second.GetGraph(), library); } } @@ -2052,7 +2049,7 @@ struct PathDetails { struct SubgraphAndClusterHash { inline std::size_t operator()(const SubgraphAndCluster& v) const { return hash()( - strings::StrCat(v.subgraph, v.outside_compilation_cluster)); + absl::StrCat(v.subgraph, v.outside_compilation_cluster)); } }; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 7bc0ef030302dc6495e3e6d1151f458b450ed2c3..49958093b8dcf35e8adcdfd2f7dfce8558d5db6f 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "absl/strings/match.h" @@ -48,7 +49,7 @@ Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder, FunctionDef* fdef = library->add_function(); TF_RETURN_IF_ERROR(GraphToFunctionDef( *graph, - strings::StrCat("_outside_compilation_shape_inference_", name_suffix), + absl::StrCat("_outside_compilation_shape_inference_", name_suffix), fdef)); return Status::OK(); } @@ -65,18 +66,18 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, const auto iter = b.find(elt_a.first); if (iter == b.end()) { if (diff) { - *diff = strings::StrCat( - map_name, " expected: contains element with key '", - key_to_string(elt_a.first), "' got: map has no such element"); + *diff = absl::StrCat(map_name, " expected: contains element with key '", + key_to_string(elt_a.first), + "' got: map has no such element"); } return false; } if (!compare(elt_a.first, elt_a.second, iter->second)) { if (diff) { - *diff = strings::StrCat(map_name, " expected: element with key '", - key_to_string(elt_a.first), "' has value '", - value_to_string(elt_a.second), "' got: '", - value_to_string(iter->second), "'"); + *diff = absl::StrCat(map_name, " expected: element with key '", + key_to_string(elt_a.first), "' has value '", + value_to_string(elt_a.second), "' got: '", + value_to_string(iter->second), "'"); } return false; } @@ -85,9 +86,9 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, const auto iter = a.find(elt_b.first); if (iter == a.end()) { if (diff) { - *diff = strings::StrCat(map_name, " got: contains element with key '", - key_to_string(elt_b.first), - "' expected: map has no such element"); + *diff = absl::StrCat(map_name, " got: contains element with key '", + key_to_string(elt_b.first), + "' expected: map has no such element"); } return false; } @@ -99,25 +100,25 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, const 
string& diff_preamble, string* diff) { if (a.op() != b.op()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected op '", a.op(), "' got '", b.op()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected op '", a.op(), "' got '", b.op()); } return false; } if (a.device() != b.device()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected device '", a.device(), "' got '", - b.device()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected device '", a.device(), "' got '", + b.device()); } return false; } if (a.input_size() != b.input_size()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected ", a.input_size(), " inputs got ", - b.input_size(), " expected:\n", a.DebugString(), - "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected ", a.input_size(), " inputs got ", + b.input_size(), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); } return false; } @@ -127,10 +128,10 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, if (absl::StartsWith(a.input(i), "^")) { if (!absl::StartsWith(b.input(i), "^")) { if (diff) { - *diff = strings::StrCat( - diff_preamble, " mismatch for node ", a.name(), " input ", i, - ", expected control input ", a.input(i), " got ", b.input(i), - " expected:\n", a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected control input ", + a.input(i), " got ", b.input(i), " expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); } return false; } @@ -138,19 +139,19 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, control_input_b.insert(b.input(i)); } else if (a.input(i) != b.input(i)) { if (diff) { - *diff = strings::StrCat(diff_preamble, " 
mismatch for node ", a.name(), - " input ", i, ", expected ", a.input(i), - " got ", b.input(i), " expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected ", a.input(i), " got ", + b.input(i), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); } return false; } } if (control_input_a != control_input_b) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - " control inputs differ expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " control inputs differ expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); } return false; } @@ -170,18 +171,17 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, return av.DebugString() == bv.DebugString(); } }, - strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), - diff); + absl::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff); } bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, string* diff) { if (a.signature().DebugString() != b.signature().DebugString()) { if (diff) { - *diff = strings::StrCat("Signature mismatch for function ", - a.signature().name(), ", expected:\n", - a.signature().DebugString(), "\ngot:\n", - b.signature().DebugString()); + *diff = + absl::StrCat("Signature mismatch for function ", a.signature().name(), + ", expected:\n", a.signature().DebugString(), "\ngot:\n", + b.signature().DebugString()); } return false; } @@ -191,7 +191,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, [](const string& key, const AttrValue& av, const AttrValue& bv) { return av.DebugString() == bv.DebugString(); }, - strings::StrCat("attr mismatch for function ", a.signature().name()), + absl::StrCat("attr mismatch for function ", a.signature().name()), diff)) { return false; } @@ -201,7 +201,7 @@ bool 
EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, [](const string& key, const string& av, const string& bv) { return av == bv; }, - strings::StrCat("ret mismatch for function ", a.signature().name()), + absl::StrCat("ret mismatch for function ", a.signature().name()), diff)) { return false; } @@ -211,7 +211,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, if (a.node_def(i).name() == b.node_def(j).name()) { if (!EqualFunctionNodeDef( a.node_def(i), b.node_def(j), - strings::StrCat("Function ", a.signature().name()), diff)) { + absl::StrCat("Function ", a.signature().name()), diff)) { return false; } found = true; @@ -220,9 +220,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } if (!found) { if (diff) { - *diff = strings::StrCat("Function ", a.signature().name(), - ", expected: has node '", a.node_def(i).name(), - "' got: no node of that name"); + *diff = absl::StrCat("Function ", a.signature().name(), + ", expected: has node '", a.node_def(i).name(), + "' got: no node of that name"); } return false; } @@ -237,9 +237,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } if (!found) { if (diff) { - *diff = strings::StrCat("Function ", a.signature().name(), - ", got: has node '", b.node_def(i).name(), - "' expected: no node of that name"); + *diff = absl::StrCat("Function ", a.signature().name(), + ", got: has node '", b.node_def(i).name(), + "' expected: no node of that name"); } return false; } @@ -258,8 +258,8 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, auto it = actual_index.find(expected_function.signature().name()); if (it == actual_index.end()) { if (diff) { - *diff = strings::StrCat("Did not find expected function '", - expected_function.signature().name(), "'"); + *diff = absl::StrCat("Did not find expected function '", + expected_function.signature().name(), "'"); } return false; } @@ -269,9 +269,9 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& 
expected, if (!actual_index.empty()) { if (diff != nullptr) { - *diff = strings::StrCat("Found unexpected function '", - actual_index.begin()->second->signature().name(), - "'"); + *diff = + absl::StrCat("Found unexpected function '", + actual_index.begin()->second->signature().name(), "'"); } return false; } @@ -420,10 +420,9 @@ Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, const string& oc_cluster, absl::Span dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = - strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); - string name = strings::StrCat("outside_compilation_", cluster, "_", - oc_cluster, "_recv"); + string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = + absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_recv"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); node_builder.Input(std::move(key_input)); @@ -440,10 +439,9 @@ Node* SendFromHost(ops::NodeOut key_input, const string& cluster, const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = - strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); - string name = strings::StrCat("outside_compilation_", cluster, "_", - oc_cluster, "_send"); + string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = + absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_send"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); @@ -682,8 +680,8 @@ std::vector> GraphEdges(const Graph& graph) { for (const Edge* edge : graph.edges()) { if (edge->src()->IsSource() || edge->dst()->IsSink()) continue; edges.emplace_back( - strings::StrCat(edge->src()->name(), ":", edge->src_output()), - 
strings::StrCat(edge->dst()->name(), ":", edge->dst_input())); + absl::StrCat(edge->src()->name(), ":", edge->src_output()), + absl::StrCat(edge->dst()->name(), ":", edge->dst_input())); } std::sort(edges.begin(), edges.end()); return edges; diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD index 676f71a75aede2a7720ae0c8a579d64cc184509a..8212956adfeca263334e3d0d954f69e13c1ba28d 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -14,6 +14,7 @@ cc_library( hdrs = ["graphcycles.h"], deps = [ "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", ], ) diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index 805bbc62c1e2e877de87ab8faf3d60b829743df8..756377bd9502d7172b29f317c471963d1dee09a9 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -34,7 +34,7 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -44,7 +44,7 @@ namespace { typedef std::unordered_set NodeSet; template struct VecStruct { - typedef gtl::InlinedVector type; + typedef absl::InlinedVector type; }; template using Vec = typename VecStruct::type; diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index c37b6112cc8a92047d495d057f59e2281710e678..5dcf754969f1709bd0e211b456bc634766239980 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -21,6 +21,18 @@ limitations under the License. 
namespace tensorflow { +// PRE_PLACEMENT passes: + +// from +// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc +// FunctionalizeControlFlowPass: 27 +// +// This pass looks at the graph and all associated FunctionDefs, and turns +// traditional control flow structure (Switch/Merge/etc.) into functional +// control flow structure (XlaIf/XlaWhile). Following passes must +// handle those FunctionDef correctly. + +// POST_REWRITE_FOR_EXEC passes: REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 5b6692f523658749f7ef48f9d7d89e97d4ce8b09..07c5b2318851ed506711b9ee00c66fe680a3afd8 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -28,18 +28,6 @@ cc_library( ], ) -cc_library( - name = "parallel_check_op_flags", - srcs = ["parallel_check_op_flags.cc"], - hdrs = ["parallel_check_op_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "xla_device_flags", srcs = ["xla_device_flags.cc"], diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc deleted file mode 100644 index a61694b49407b923b7c83f35e903ef49a2175f0e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's parallel_check_op module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static ParallelCheckOpFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new ParallelCheckOpFlags; - flags->parallel_check_failfast = true; - flags->parallel_check_atol = "1e-5"; - flags->parallel_check_rtol = "1e-5"; - flag_list = new std::vector({ - Flag("parallel_check_failfast", &flags->parallel_check_failfast, - "Fail immediately on first parallel-check comparison error."), - Flag("parallel_check_atol", &flags->parallel_check_atol, - "Absolute error tolerance for parallel-check comparison."), - Flag("parallel_check_rtol", &flags->parallel_check_rtol, - "Relative error tolerance for parallel-check comparison."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// parallel_check_op module. 
-void AppendParallelCheckOpFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the ParallelCheckOpFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -ParallelCheckOpFlags* GetParallelCheckOpFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h deleted file mode 100644 index 156a2a2a71097631e24d154b102cd9b85a990b3a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ - -// Legacy flags for the XLA bridge's parallel_check_op module. 
- -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// parallel_check_op module. -void AppendParallelCheckOpFlags(std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// parallel_check_op module. -typedef struct { - bool parallel_check_failfast; // Fail immediately on first parallel-check - // comparison error. - string parallel_check_atol; // Absolute error tolerance for parallel-check - // comparison. - string parallel_check_rtol; // Relative error tolerance for parallel-check - // comparison. -} ParallelCheckOpFlags; - -// Return a pointer to the ParallelCheckOpFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -ParallelCheckOpFlags* GetParallelCheckOpFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 4e4abade3278089a1c7f8fdee46a34b8ce503651..e6cc6e52ae537c23d18dc2d3fb94b40a5d23b1a5 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -43,7 +43,6 @@ limitations under the License. 
#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -444,7 +443,7 @@ Status FindCompilationCandidates( !registration->requires_compilation) { const OpDef* op_def; TF_RETURN_IF_ERROR( - OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def)); + graph.op_registry()->LookUpOpDef(node->type_string(), &op_def)); if (op_def->is_stateful()) { // We need to be able to constant fold the nodes in // compile_time_const_nodes given constant inputs (required by XLA) and @@ -617,7 +616,7 @@ Status MarkForCompilationPass::Run( } static string RatioToString(int numerator, int denominator) { - return strings::Printf("%d / %d (%.2f%%)", numerator, denominator, + return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator, (100.0 * numerator) / denominator); } @@ -626,14 +625,14 @@ static void VLogClusteringSummary(const Graph& g) { return; } - std::map cluster_name_to_size; - std::map> + std::map cluster_name_to_size; + std::map> cluster_name_to_op_histogram; - std::map unclustered_op_histogram; + std::map unclustered_op_histogram; int clustered_node_count = 0; for (Node* n : g.nodes()) { - absl::optional cluster_name = GetXlaClusterForNode(*n); + absl::optional cluster_name = GetXlaClusterForNode(*n); if (cluster_name) { clustered_node_count++; cluster_name_to_size[*cluster_name]++; @@ -650,7 +649,7 @@ static void VLogClusteringSummary(const Graph& g) { << RatioToString(clustered_node_count, g.num_nodes()); for (const auto& cluster_name_size_pair : cluster_name_to_size) { - StringPiece cluster_name = cluster_name_size_pair.first; + absl::string_view cluster_name = cluster_name_size_pair.first; int size = cluster_name_size_pair.second; VLOG(2) << " " << cluster_name << " " << RatioToString(size, g.num_nodes()); @@ -670,14 +669,15 @@ 
static void VLogClusteringSummary(const Graph& g) { } struct EdgeInfo { - StringPiece node_name; - absl::optional cluster_name; + absl::string_view node_name; + absl::optional cluster_name; - StringPiece GetClusterName() const { + absl::string_view GetClusterName() const { return cluster_name ? *cluster_name : "[none]"; } - std::pair> AsPair() const { + std::pair> AsPair() + const { return {node_name, cluster_name}; } @@ -686,19 +686,21 @@ static void VLogClusteringSummary(const Graph& g) { } }; - using EdgeInfoMap = std::map>; + using EdgeInfoMap = std::map>; EdgeInfoMap incoming_edge_infos; EdgeInfoMap outgoing_edge_infos; - std::set cluster_names_to_print; + std::set cluster_names_to_print; for (const Edge* e : g.edges()) { const Node* from = e->src(); - absl::optional from_cluster_name = GetXlaClusterForNode(*from); + absl::optional from_cluster_name = + GetXlaClusterForNode(*from); const Node* to = e->dst(); - absl::optional to_cluster_name = GetXlaClusterForNode(*to); + absl::optional to_cluster_name = + GetXlaClusterForNode(*to); if (to_cluster_name == from_cluster_name) { continue; @@ -721,9 +723,9 @@ static void VLogClusteringSummary(const Graph& g) { VLOG(2) << " [none]"; } - auto print_edge_info_set_for_cluster = [&](StringPiece cluster_name, + auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name, const EdgeInfoMap& edge_info_map, - StringPiece desc) { + absl::string_view desc) { auto it = edge_info_map.find(cluster_name); if (it != edge_info_map.end()) { VLOG(2) << " " << it->second.size() << " " << desc << " edges"; @@ -737,7 +739,7 @@ static void VLogClusteringSummary(const Graph& g) { } }; - for (StringPiece cluster_name : cluster_names_to_print) { + for (absl::string_view cluster_name : cluster_names_to_print) { VLOG(2) << " ** Cluster " << cluster_name; print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos, "incoming"); @@ -966,7 +968,7 @@ Status MarkForCompilationPass::RunImpl( string& name = 
cluster_names[cluster]; if (name.empty()) { - name = strings::StrCat("cluster_", cluster_sequence_num++); + name = absl::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 807ab51fd3c133b95915ea88e0bf99dbb8661452..c59770a4c8d4a5cb8508a928677f34aeb3d6acf5 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/memory/memory.h" #include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" @@ -633,7 +634,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { std::unique_ptr graph(new Graph(OpRegistry::Global())); Scope root = Scope::NewRootScope().ExitOnError(); { - auto BuildNoopNode = [](StringPiece name, Graph* graph) { + auto BuildNoopNode = [](absl::string_view name, Graph* graph) { NodeDefBuilder builder(name, "NoOp"); NodeDef def; TF_CHECK_OK(builder.Finalize(&def)); @@ -847,5 +848,51 @@ TEST(XlaCompilationTest, RandomShape) { EXPECT_EQ(clusters["shape"], ""); } +TEST(XlaCompilationTest, RandomShapeWithFunc) { + Scope root = Scope::DisabledShapeInferenceScope().ExitOnError(); + + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/"Stateful_func", /*in_def=*/{}, + /*out_def=*/{"out: int32"}, + /*attr_def*/ + {}, /*node_def=*/ + {FunctionDefHelper::Const("shape_shape", 2), + FunctionDefHelper::Const("minval", 1), + FunctionDefHelper::Const("maxval", 20), + {{"shape"}, + "RandomUniformInt", + {"shape_shape:output:0", "minval:output:0", "maxval:output:0"}, + {{"Tout", DataType::DT_INT32}, {"T", 
DataType::DT_INT32}}}}, + /*ret_def=*/{{"out", "shape:output:0"}}); + + func.mutable_signature()->set_is_stateful(true); + *flib_def.add_function() = std::move(func); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + NodeDef call_node; + call_node.set_name("fn_call"); + call_node.set_op("Stateful_func"); + Status status; + Node* call = root.graph()->AddNode(call_node, &status); + TF_ASSERT_OK(status); + + Output shape = Output(call, 0); + Output reshape_input = + ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("reshape"), reshape_input, shape); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + auto fld = absl::make_unique(OpRegistry::Global(), + flib_def); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get())); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["fn_call"], ""); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index a8f09bfa5034e020fe3448d8ecfe0f70605e14d2..10fc9e85d927ffe2416d6d9e6dfd24b286fbf1a0 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -14,7 +14,11 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -30,7 +34,7 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, MemoryTypeVector input_mtypes, output_mtypes; for (Node* n : post_order) { - absl::optional from_cluster = GetXlaClusterForNode(*n); + absl::optional from_cluster = GetXlaClusterForNode(*n); if (!from_cluster) { continue; } @@ -79,7 +83,7 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, // Check if `dst` is in a different cluster, unclustered, or about to be // partially declustered (here we rely on the post-order traversal order). // If yes, decluster `n` to avoid the device-to-host memcpy. - absl::optional dst_cluster = + absl::optional dst_cluster = result->count(dst) ? 
absl::nullopt : GetXlaClusterForNode(*dst); if (from_cluster != dst_cluster) { CHECK(result->insert(n).second); @@ -91,15 +95,16 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, } Status PartiallyDeclusterNode(Graph* graph, Node* n) { - StringPiece cluster_name = *GetXlaClusterForNode(*n); - gtl::InlinedVector out_edges_to_clone; + absl::string_view cluster_name = *GetXlaClusterForNode(*n); + absl::InlinedVector out_edges_to_clone; for (const Edge* out_edge : n->out_edges()) { if (out_edge->IsControlEdge()) { continue; } Node* dst = out_edge->dst(); - absl::optional dst_cluster_name = GetXlaClusterForNode(*dst); + absl::optional dst_cluster_name = + GetXlaClusterForNode(*dst); if (dst_cluster_name != cluster_name) { out_edges_to_clone.push_back(out_edge); } @@ -108,7 +113,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { CHECK(!out_edges_to_clone.empty()) << n->DebugString(); NodeDef ndef = n->def(); - ndef.set_name(strings::StrCat(n->name(), "/declustered")); + ndef.set_name(absl::StrCat(n->name(), "/declustered")); RemoveFromXlaCluster(&ndef); Status s; Node* cloned_node = graph->AddNode(ndef, &s); @@ -128,30 +133,47 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { return Status::OK(); } -} // namespace -Status PartiallyDeclusterPass::Run( - const GraphOptimizationPassOptions& options) { - // NB! In this pass we assume the only XLA-auto-clusterable operations that - // may have side effects are resource variable operations so we don't cluster - // those. The pass will have to be updated if this assumption becomes - // invalid. - - Graph* graph = options.graph->get(); +bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } +// Clones nodes to outside their cluster to avoid device-to-host copies. For +// instance, converts this: +// +// ..... +// | +// v +// A_Clustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// to: +// +// ..... 
+// | | +// | +-------------+ +// | | +// v v +// A_Clustered A_Unclustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// where the ===> arrow has a hostmem source and destination and would entail a +// device to host copy if the source and destination were not in the same XLA +// cluster. +Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { // When deciding whether to decluster a particular node, we base our decision // on if we've decided that some of its consumers have to be declustered too. // Iterating the graph in post-order guarantees that consumers have been // visited before producers. std::vector post_order; GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), - /*edge_filter=*/[](const Edge& edge) { - return !edge.src()->IsNextIteration(); - }); + /*edge_filter=*/NotBackedge); gtl::FlatSet nodes_to_partially_decluster; - TF_RETURN_IF_ERROR(FindNodesToDecluster( - **options.graph, &nodes_to_partially_decluster, post_order)); + TF_RETURN_IF_ERROR( + FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); if (VLOG_IS_ON(3)) { for (Node* n : post_order) { @@ -168,10 +190,133 @@ Status PartiallyDeclusterPass::Run( } nodes_to_partially_decluster.clear(); - TF_RETURN_IF_ERROR(FindNodesToDecluster( - **options.graph, &nodes_to_partially_decluster, post_order)); + TF_RETURN_IF_ERROR( + FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); CHECK(nodes_to_partially_decluster.empty()); return Status::OK(); } + +bool IsIntraClusterEdge(const Edge& edge) { + absl::optional src_cluster_name = + GetXlaClusterForNode(*edge.src()); + absl::optional dst_cluster_name = + GetXlaClusterForNode(*edge.dst()); + return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name; +} + +Status MustCompileNode(const Node* n, bool* result) { + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n->assigned_device_name(), &device_type)); + + const 
XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) { + *result = false; + } else { + *result = registration->requires_compilation; + } + + return Status::OK(); +} + +// Declusters nodes to reduce the number of times we think we need to recompile +// a TensorFlow graph. +// +// Abstractly, if we have a cluster of this form: +// +// x0 = arg0 +// x1 = arg1 +// ... +// shape = f(x0, x1, ...) +// result = Reshape(input=, new_shape=shape) +// +// then pulling `f` out of the cluster may reduce the number of compilations and +// will never increase the number of compilations. +// +// We may reduce the number of compilations if f is many to one. For instance +// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different +// compilations if f is in the cluster but only one compilation if f is outside +// the cluster. +// +// Declustering f will increase the number of compilations only if f is a +// one-to-many "function" i.e. isn't a function at all. RNG is one possible +// example, depending on how we look at it. But we never create clusters where +// such f's would be marked as must-be-constant. +// +// We assume here that the extra repeated (repeated compared to a clustered f +// where it will always be constant folded) host-side computation of f does not +// regress performance in any significant manner. We will have to revisit this +// algorithm with a more complex cost model if this assumption turns out to be +// incorrect. 
+Status DeclusterNodesToReduceRecompilations(Graph* graph) { + std::vector compile_time_const_nodes(graph->num_node_ids()); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge)); + + std::vector rpo; + GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/NotBackedge); + for (Node* n : rpo) { + if (!compile_time_const_nodes[n->id()]) { + continue; + } + + absl::string_view cluster_name = *GetXlaClusterForNode(*n); + bool node_on_cluster_edge = + absl::c_all_of(n->in_edges(), [&](const Edge* e) { + absl::optional incoming_cluster = + GetXlaClusterForNode(*e->src()); + return !incoming_cluster || *incoming_cluster != cluster_name; + }); + + // We don't want to decluster F in a graph like + // + // Input -> OP -> Shape -> F -> Reshape + // + // Doing so will break up the cluster. Even if we were okay with breaking + // up the cluster we will at least have to relabel the two clusters to have + // different cluster names. + // + // We may want to revisit this in the future: we may have cases where OP is + // a small computation that does not benefit from XLA while XLA can optimize + // everything that follows the Reshape. In these cases it may be wise to + // remove Input, OP, Shape and F from the cluster, if F is a many-to-one + // function. + // + // Note that we do do the right thing for graphs like: + // + // Input -> F0 -> F1 -> Reshape + // + // Since we iterate in RPO, we'll first encounter F0, decluster it, then + // encounter F1, decluster it and so on. + if (node_on_cluster_edge) { + bool must_compile_node; + TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node)); + if (!must_compile_node) { + VLOG(3) << "Declustering must-be-constant node " << n->name(); + RemoveFromXlaCluster(n); + } + } + } + + return Status::OK(); +} + +} // namespace + +Status PartiallyDeclusterPass::Run( + const GraphOptimizationPassOptions& options) { + // NB! 
In this pass we assume the only XLA-auto-clusterable operations that + // may have side effects are resource variable operations so we don't cluster + // those. The pass will have to be updated if this assumption becomes + // invalid. + + Graph* graph = options.graph->get(); + + TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph)); + TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph)); + + return Status::OK(); +} } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h index 6949b5028ee55e182b27589f9a9711dad7839e86..cfc4ddb5630bec91d6942c983ce1efae3a735c43 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.h +++ b/tensorflow/compiler/jit/partially_decluster_pass.h @@ -20,34 +20,11 @@ limitations under the License. namespace tensorflow { -// Clones nodes from within a cluster to outside the cluster if profitable. +// Clones or moves nodes from within a cluster to outside the cluster if +// profitable. There are two reasons why we do this: // -// Today this only clones to avoid device-to-host copies, but in the future we -// may consider other reasons to clone. For instance, we convert this: -// -// ..... -// | -// v -// A_Clustered ====> C_Unclustered -// | -// v -// B_Clustered -// -// to: -// -// ..... -// | | -// | +-------------+ -// | | -// v v -// A_Clustered A_Unclustered ====> C_Unclustered -// | -// v -// B_Clustered -// -// where the ===> arrow has a hostmem source and destination and would entail a -// device to host copy if the source and destination were not in the same XLA -// cluster. +// - Reducing device-to-host copies. +// - Reducing the number of XLA recompilations. 
class PartiallyDeclusterPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index f61a955c222dd7ce11a177cd54bb8851a5400496..35872daa658810707c12fb5020ee6d913167946b 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "absl/memory/memory.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" @@ -31,6 +32,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -82,7 +84,9 @@ Status PartiallyDecluster(std::unique_ptr* graph) { // Assign all nodes to the CPU device. 
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; for (Node* n : (*graph)->nodes()) { - n->set_assigned_device_name(kCpuDevice); + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name(kCpuDevice); + } } GraphOptimizationPassOptions opt_options; @@ -91,8 +95,8 @@ Status PartiallyDecluster(std::unique_ptr* graph) { return pass.Run(opt_options); } -const Node* FindNodeByName(const Graph& graph, const string& name) { - for (const Node* node : graph.nodes()) { +Node* FindNodeByName(const Graph& graph, const string& name) { + for (Node* node : graph.nodes()) { if (node->name() == name) { return node; } @@ -279,5 +283,128 @@ TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) { "ClusteredProducer0/declustered"); EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input"); } + +void AddToCluster(absl::Span nodes, + absl::string_view cluster_name) { + for (Node* n : nodes) { + n->AddAttr(kXlaClusterAttr, string(cluster_name)); + } +} + +TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({shape.node(), reshape.node()}, "cluster_0"); + + auto graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + const Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + + EXPECT_EQ(GetXlaClusterForNode(*n), absl::nullopt); +} + +TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) { + 
tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output input_b = ops::Placeholder(s.WithOpName("shape_b"), DT_FLOAT, + ops::Placeholder::Attrs{}); + Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b); + Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul); + + Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()}, + "cluster_0"); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + const Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); +} + +TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({reshape.node()}, "cluster_0"); + AddToCluster({shape.node()}, "cluster_1"); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + const Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + + 
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_1"); +} + +TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({shape.node(), reshape.node()}, "cluster_0"); + + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + + // This is needed to register the XLA_GPU device. + std::vector devices; + TF_ASSERT_OK(DeviceFactory::AddDevices( + SessionOptions(), "/job:localhost/replica:0/task:0", &devices)); + + // Scope::ToGraph loses the assigned device name since it goes through + // GraphDef/NodeDef which does not have a field for the assigned device name. 
+ Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + n->set_assigned_device_name( + "/job:localhost/replica:0/task:0/device:XLA_GPU:0"); + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); + + for (Device* d : devices) { + delete d; + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 1ba4a5ef7399111e512da8c4966f5899ed828b17..56e35c0059124015266ffabdf583c8724c8e0908 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -165,7 +165,7 @@ bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { using ResourceOp = std::pair; string ResourceOpToString(const ResourceOp& resource_op) { - return strings::StrCat( + return absl::StrCat( resource_op.first, ": ", XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second)); } @@ -257,11 +257,11 @@ string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { std::vector elements_debug_string; std::transform(resource_op_set.begin(), resource_op_set.end(), std::back_inserter(elements_debug_string), ResourceOpToString); - return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); + return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); } string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { - return strings::StrCat( + return absl::StrCat( "[", n.name(), ": ", n.type_string(), "(", XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]"); } diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 4f2fabd658330b8ab182e13e02ed0bca41641e46..f85121ca27ad3da918315f93b28e9000dfd65e67 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -17,6 
+17,7 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" @@ -52,8 +53,8 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, }; string description; - strings::StrAppend(&description, "Edge from ", node_name(src), " to ", - node_name(dst), " would create a cycle.\n"); + absl::StrAppend(&description, "Edge from ", node_name(src), " to ", + node_name(dst), " would create a cycle.\n"); path.resize(path_size); for (int32 node_id : path) { string ascii_art; @@ -64,7 +65,7 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, } else { ascii_art = "+-- "; } - strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); + absl::StrAppend(&description, ascii_art, node_name(node_id), "\n"); } return description; } @@ -186,7 +187,7 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { return Status::OK(); } -absl::optional GetXlaClusterForNode(const Node& node) { +absl::optional GetXlaClusterForNode(const Node& node) { const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr); if (attr_value == nullptr) { return absl::nullopt; @@ -209,6 +210,8 @@ void RemoveFromXlaCluster(NodeDef* node_def) { node_def->mutable_attr()->erase(kXlaClusterAttr); } +void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); } + Status AdjustCycleDetectionGraphForResourceOps( const Graph* graph, const FunctionLibraryDefinition* flib_def, const std::function& resource_ops_to_ignore, diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index b0439a63ca6476b6b1d63e65308712270381dd9f..ba218f3315d2607c47342fdade0403678faa2362 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -47,11 +47,14 @@ Status 
CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, // otherwise returns nullopt. -absl::optional GetXlaClusterForNode(const Node& node); +absl::optional GetXlaClusterForNode(const Node& node); // Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute). void RemoveFromXlaCluster(NodeDef* node_def); +// Removes `node` its XLA cluster (by clearing its _XlaCluster attribute). +void RemoveFromXlaCluster(Node* node); + // Returns true if `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index ef6b0e67d3c4007f86dc7eef89cacb4cea98fc15..3aa9e9c7ed2dd3b7480f40e868c6b07192b68294 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -67,12 +67,12 @@ string XlaCompilationCache::DebugString() { string XlaCompilationCache::SignatureDebugString(const Signature& sig) { string result = sig.name; for (const auto& a : sig.arg_types) { - strings::StrAppend(&result, ",", DataTypeString(a.first), - a.second.DebugString()); + absl::StrAppend(&result, ",", DataTypeString(a.first), + a.second.DebugString()); } for (const auto& v : sig.arg_values) { - strings::StrAppend(&result, "; ", v.DebugString()); + absl::StrAppend(&result, "; ", v.DebugString()); } return result; } @@ -259,7 +259,7 @@ Status XlaCompilationCache::CompileImpl( const XlaCompiler::CompileOptions& compile_options, bool compile_single_op) { CHECK_NE(executable, nullptr); - VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); + VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { VLOG(2) << "num_inputs=" << ctx->num_inputs() @@ -310,7 +310,7 @@ Status XlaCompilationCache::CompileImpl( // cache eviction. 
mutex_lock entry_lock(entry->mu); if (!entry->compiled) { - VLOG(1) << "Compilation cache miss for signature: " + VLOG(2) << "Compilation cache miss for signature: " << SignatureDebugString(signature); tensorflow::Env* env = tensorflow::Env::Default(); const uint64 compile_start_us = env->NowMicros(); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index f31879a2bc517d8b05e129cf0777196d0ee4dc79..51797def041d5d223d22fb28408ec91290a1400d 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -148,10 +148,9 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { } const DeviceAttributes attrs = Device::BuildDeviceAttributes( - strings::StrCat(name_prefix, "/device:", device_name, ":", - device_ordinal), + absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal), DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), - strings::StrCat("device: ", device_name, " device")); + absl::StrCat("device: ", device_name, " device")); device->reset( new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name), diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index ee07c5c9643ef1119b9077326c1cf7c83930e90c..af83c792e5e11d8596c521c6a3aed332a1f42e5b 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -203,7 +203,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { @@ -339,7 +339,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* 
cpu_tensor, StatusCallback done) { manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 2e7445340cbaf788bfd06260f4376596895231c1..df824212948ac96a5df5228cecd9a8c864bbec9a 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -57,7 +57,7 @@ class XlaTransferManager { void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done); void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, @@ -111,7 +111,7 @@ class XlaDeviceContext : public DeviceContext { Tensor* device_tensor, StatusCallback done) const override; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 13da5d2f948df671df6d0d80687321eaaa923943..49c85826829fb44d58f10e084f8d757d65bf1882 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -198,33 +198,33 @@ class XlaAssignVariableOp : public AsyncOpKernel { \ REGISTER_KERNEL_BUILDER( \ Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \ - GeneratorDatasetOp); \ + data::GeneratorDatasetOp); \ REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \ .Device(DEVICE) \ .HostMemory("buffer_size") \ .HostMemory("input_dataset") \ .HostMemory("handle"), \ - PrefetchDatasetOp); \ + data::PrefetchDatasetOp); \ \ 
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \ - IteratorHandleOp); \ + data::IteratorHandleOp); \ REGISTER_KERNEL_BUILDER( \ Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \ - MakeIteratorOp); \ + data::MakeIteratorOp); \ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ - AnonymousIteratorHandleOp); \ + data::AnonymousIteratorHandleOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ - IteratorGetNextOp); \ + data::IteratorGetNextOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \ - IteratorGetNextSyncOp); \ + data::IteratorGetNextSyncOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ .Device(DEVICE) \ .HostMemory("string_handle"), \ - IteratorToStringHandleOp); \ + data::IteratorToStringHandleOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \ .Device(DEVICE) \ .HostMemory("string_handle"), \ - IteratorFromStringHandleOp); \ + data::IteratorFromStringHandleOp); \ REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ .Device(DEVICE) \ .HostMemory("output") \ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 07cfab615157650aea0e15cdafa8c9b0925f9e5f..bc0db558d8d0b7c666efcfac5c4926144b830380 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -326,7 +327,7 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, string& name = cluster_names[cluster]; if (name.empty()) { - name = strings::StrCat("cluster_", cluster_sequence_num++); + name = absl::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 4c9bb2e27b0ca3c83848be7fdf189fdbad89cee5..d95da63405889dfd0c279b17789a2195072c7277 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -122,7 +122,7 @@ class XlaTensor { std::shared_ptr definition_event_; // A list of all streams for which the tensor's content is defined for any // newly enqueued command. 
- gtl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); + absl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); mutex mu_; }; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 34defe1c7ade687a7524390cee78657e1a27f5b4..2176eaebe4dd61a23be26ed32f68ffbee8b64c53 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -581,6 +581,7 @@ tf_xla_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework", "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", ], ) @@ -1103,6 +1104,7 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/strings", ], ) @@ -1196,7 +1198,7 @@ tf_xla_py_test( tf_xla_py_test( name = "xla_ops_test", - size = "small", + size = "medium", srcs = ["xla_ops_test.py"], disabled_backends = ["cpu_ondemand"], deps = [ diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index df0f21471a1c67e69e037f6409bcab1297d3399d..058576b3d4b695209952158769162bb24e7ccfce 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -56,7 +56,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. @@ -98,7 +98,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): # TODO: test fails for float16 due to excessive precision requirements. 
if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. @@ -140,7 +140,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): # TODO: test fails for float16 due to excessive precision requirements. if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 04f3b3ef4905984b0432a536c3b1c275738ede17..0af74c2d8f243d8f5ccf1373e0706039cc8ef041 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -58,7 +58,8 @@ class DenseLayerTest(test.TestCase): Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. """ - os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit") + os.environ["TF_XLA_FLAGS"] = ( + "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) config = config_pb2.ConfigProto() config.graph_options.optimizer_options.global_jit_level = ( config_pb2.OptimizerOptions.ON_1) @@ -77,7 +78,7 @@ class DenseLayerTest(test.TestCase): labels = GetRunMetadataLabels(run_metadata) self.assertEqual(1, XlaLaunchOpCount(labels)) - self.assertFalse(InLabels(labels, "ListDiff")) + self.assertFalse(InLabels(labels, "MatMult")) def testDenseLayerJitScopeDefinedShape(self): """Tests that the dense layer node is properly compiled in jit scope. 
@@ -128,7 +129,7 @@ class DenseLayerTest(test.TestCase): labels = GetRunMetadataLabels(run_metadata) self.assertEqual(2, XlaLaunchOpCount(labels)) - self.assertFalse(InLabels(labels, "ListDiff")) + self.assertFalse(InLabels(labels, "MatMult")) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 6e0db54b7a74b284dc7d18bcbb07c178c664c1e5..0839fb123e83960e198eac2bed769afbdd517889 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -489,8 +489,9 @@ class ElementWiseFusionTest(test.TestCase): def testElementWiseClustering(self): arg0 = np.random.rand(2, 2).astype(np.float32) arg1 = np.random.rand(2, 2).astype(np.float32) - os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true " - "--tf_xla_cpu_global_jit") + os.environ["TF_XLA_FLAGS"] = ( + "--tf_xla_fusion_only=true " + "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) tf_op, tf_count = self.simpleTest(arg0, arg1, config_pb2.OptimizerOptions.OFF) self.assertEqual(0, tf_count) diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 9222db4b7ebf020c8cee1c0af81e05129fb33c4d..c61965b97fc142ce452cf28def8c937f692d2f84 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test @@ -26,38 +27,167 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class MatrixBandPartTest(xla_test.XLATestCase): +class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase): - def _testMatrixBandPart(self, dtype, shape): - with self.cached_session(): - batch_shape = shape[:-2] - mat = 
np.ones(shape).astype(dtype) - batch_mat = np.tile(mat, batch_shape + [1, 1]) - for lower in -1, 0, 1, shape[-2] - 1: - for upper in -1, 0, 1, shape[-1] - 1: - band_np = mat - if lower >= 0: - band_np = np.triu(band_np, -lower) - if upper >= 0: - band_np = np.tril(band_np, upper) - if batch_shape: - band_np = np.tile(band_np, batch_shape + [1, 1]) - - placeholder = array_ops.placeholder(dtype) - with self.test_scope(): - band = array_ops.matrix_band_part( - placeholder, - constant_op.constant(lower, dtype=dtypes.int32), - constant_op.constant(upper, dtype=dtypes.int32)) - feed_dict = {placeholder: batch_mat} - self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) - - def testMatrixBandPart(self): + @parameterized.parameters( + { + 'batch_shape': [], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [], + 'rows': 7, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [2,], + 'rows': 7, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 1, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 1 + }, + 
{ + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 2, + 'cols': 7 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 1 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 2 + }, + { + 'batch_shape': [1, 3, 2], + 'rows': 7, + 'cols': 7 + }, + ) + def testMatrixBandPart(self, batch_shape, rows, cols): for dtype in self.float_types: - for batch_shape in [[], [2,], [1, 3, 2]]: - for rows in 1, 2, 7: - for cols in 1, 2, 7: - self._testMatrixBandPart(dtype, batch_shape + [rows, cols]) + with self.cached_session(): + mat = np.ones(batch_shape + [rows, cols]).astype(dtype) + batch_mat = np.tile(mat, batch_shape + [1, 1]) + for lower in -1, 0, 1, rows - 1: + for upper in -1, 0, 1, cols - 1: + band_np = mat + if lower >= 0: + band_np = np.triu(band_np, -lower) + if upper >= 0: + band_np = np.tril(band_np, upper) + if batch_shape: + band_np = np.tile(band_np, batch_shape + [1, 1]) + + placeholder = array_ops.placeholder(dtype) + with self.test_scope(): + band = array_ops.matrix_band_part( + placeholder, constant_op.constant(lower, dtype=dtypes.int32), + constant_op.constant(upper, dtype=dtypes.int32)) + feed_dict = {placeholder: batch_mat} + self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 0faf0fd8edf355838ccf42f1d6de20ac01faa3db..bddda6f30245d4b8281a77783ec9922d61bd3883 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -45,6 +45,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/core/common_runtime/device.h" @@ -61,7 +63,6 @@ limitations under the License. 
#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" @@ -81,7 +82,7 @@ string* tf_xla_test_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; string LocalDeviceToFullDeviceName(const string& device) { - return strings::StrCat("/job:localhost/replica:0/task:0/device:", device); + return absl::StrCat("/job:localhost/replica:0/task:0/device:", device); } constexpr std::array kAllXlaTypes = { @@ -107,11 +108,12 @@ class OpTestBuilder { // Sets an attribute. template - OpTestBuilder& Attr(StringPiece attr_name, T&& value); + OpTestBuilder& Attr(absl::string_view attr_name, T&& value); // Overload needed to allow {...} expressions for value. template - OpTestBuilder& Attr(StringPiece attr_name, std::initializer_list value); + OpTestBuilder& Attr(absl::string_view attr_name, + std::initializer_list value); // Adds nodes that executes the operator under test on 'device' to 'graphdef'. // If 'use_jit' is true, marks the operator under test to be compiled by XLA. 
@@ -185,13 +187,13 @@ OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type, } template -OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) { +OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, T&& value) { AddNodeAttr(attr_name, std::forward(value), &node_def_); return *this; } template -OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, +OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, std::initializer_list value) { Attr>(attr_name, std::move(value)); return *this; @@ -209,7 +211,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, NodeDef* test_def = graphdef->add_node(); *test_def = node_def_; - test_def->set_name(strings::StrCat(name_prefix, "_op_under_test")); + test_def->set_name(absl::StrCat(name_prefix, "_op_under_test")); test_def->set_device(device); AddDefaultsToNodeDef(*op_def, test_def); if (use_jit) { @@ -224,7 +226,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, // Build feed and fetch nodes. 
for (int i = 0; i < input_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = strings::StrCat(name_prefix, "_input_", i); + string name = absl::StrCat(name_prefix, "_input_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder") .Device(device) .Attr("dtype", input_types[i]) @@ -235,7 +237,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, for (int i = 0; i < output_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = strings::StrCat(name_prefix, "_output_", i); + string name = absl::StrCat(name_prefix, "_output_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity") .Device(device) .Attr("T", output_types[i]) @@ -726,11 +728,11 @@ bool IsClose(const complex64& x, const complex64& y, double atol, template string Str(T x) { - return strings::StrCat(x); + return absl::StrCat(x); } template <> string Str(complex64 x) { - return strings::StrCat("(", x.real(), ", ", x.imag(), ")"); + return absl::StrCat("(", x.real(), ", ", x.imag(), ")"); } template @@ -740,11 +742,11 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol, auto Ty = y.flat(); for (int i = 0; i < Tx.size(); ++i) { if (!IsClose(Tx(i), Ty(i), atol, rtol)) { - return errors::InvalidArgument(strings::StrCat( - i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ", - Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(), - "atol = ", atol, " rtol = ", rtol, - " tol = ", atol + rtol * Abs(Tx(i)))); + return errors::InvalidArgument( + absl::StrCat(i, "-th tensor element isn't close: ", Str(Tx(i)), + " vs. ", Str(Ty(i)), ". 
x = ", x.DebugString(), + "y = ", y.DebugString(), "atol = ", atol, + " rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i)))); } } return Status::OK(); @@ -756,7 +758,7 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { auto Ty = y.flat(); for (int i = 0; i < Tx.size(); ++i) { if (Tx(i) != Ty(i)) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i), ". x = ", x.DebugString(), "y = ", y.DebugString())); } @@ -771,14 +773,14 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, double rtol) { if (a.dtype() != b.dtype()) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( "Tensors have different types: ", DataTypeString(a.dtype()), " and ", DataTypeString(b.dtype()))); } if (!a.IsSameSize(b)) { - return errors::InvalidArgument(strings::StrCat( - "Tensors have different shapes: ", a.shape().DebugString(), " and ", - b.shape().DebugString())); + return errors::InvalidArgument( + absl::StrCat("Tensors have different shapes: ", a.shape().DebugString(), + " and ", b.shape().DebugString())); } switch (a.dtype()) { @@ -827,7 +829,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( } string cpu_device = - LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0")); + LocalDeviceToFullDeviceName(absl::StrCat(DEVICE_CPU, ":0")); string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); DeviceNameUtils::ParsedName parsed_name; @@ -842,7 +844,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( std::vector expected_inputs, test_inputs; std::vector expected_fetches, test_fetches; Status status = builder.BuildGraph( - strings::StrCat("test", num_tests_, "_expected"), cpu_device, + absl::StrCat("test", num_tests_, "_expected"), cpu_device, /* use_jit= */ false, &graph, /* test_node_def= */ 
nullptr, &expected_inputs, &expected_fetches); if (!status.ok()) { @@ -851,7 +853,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( } NodeDef* node_def; - status = builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"), + status = builder.BuildGraph(absl::StrCat("test", num_tests_, "_test"), test_device, tf_xla_test_use_jit, &graph, &node_def, &test_inputs, &test_fetches); if (!status.ok()) { diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py index 84c67779400f7a800bd88abc32d95058a6c0904d..96e0b074754032dd64c479b5e587b664ff066e2b 100644 --- a/tensorflow/compiler/tests/reshape_op_test.py +++ b/tensorflow/compiler/tests/reshape_op_test.py @@ -33,7 +33,7 @@ class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase): ('64_bit_index', dtypes.int64)) def testBasic(self, index_dtype): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): shape = constant_op.constant([3, 2], dtype=index_dtype) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index b2f026df6c0c28fcbceaa0493871bc12c2d23b1f..1e600c44e9af66994686359eb0e1a1e52bea93fd 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest @@ -97,9 +98,9 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32)) - 
PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT, - xla_data_pb2.PrecisionConfigProto.HIGH, - xla_data_pb2.PrecisionConfigProto.HIGHEST) + PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfig.DEFAULT, + xla_data_pb2.PrecisionConfig.HIGH, + xla_data_pb2.PrecisionConfig.HIGHEST) @parameterized.parameters(*PRECISION_VALUES) def testConv(self, precision): @@ -120,7 +121,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) precision_config = None if precision: - precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config = xla_data_pb2.PrecisionConfig() precision_config.operand_precision.extend([precision, precision]) return xla.conv( lhs, @@ -151,7 +152,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dnums.rhs_batch_dimensions.append(0) precision_config = None if precision: - precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config = xla_data_pb2.PrecisionConfig() precision_config.operand_precision.extend([precision, precision]) return xla.dot_general( lhs, @@ -296,6 +297,44 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): self._assertOpOutputMatchesExpected( lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) + def testDynamicSlice(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.dynamic_slice, + args=(np.arange(1000, + dtype=np.int32).astype(dtype).reshape([10, 10, 10]), + np.array([5, 7, 3]), np.array([2, 3, 2])), + expected=np.array( + np.array([[[573, 574], [583, 584], [593, 594]], + [[673, 674], [683, 684], [693, 694]]]), + dtype=dtype)) + + def testDynamicSliceWithIncorrectStartIndicesShape(self): + with self.test_session() as session: + with self.test_scope(): + output = xla.dynamic_slice( + np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), + np.array([5, 7]), np.array([2, 3, 4])) + with self.assertRaises(errors.InvalidArgumentError) 
as invalid_arg_error: + session.run(output) + self.assertRegexpMatches( + invalid_arg_error.exception.message, + (r'^start_indices must be a vector with length equal to input rank, ' + r'but input rank is 3 and start_indices has shape \[2\].*')) + + def testDynamicSliceWithIncorrectSizeIndicesShape(self): + with self.test_session() as session: + with self.test_scope(): + output = xla.dynamic_slice( + np.arange(1000, dtype=np.int32).reshape([10, 10, 10]), + np.array([5, 7, 3]), np.array([2, 3])) + with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error: + session.run(output) + self.assertRegexpMatches( + invalid_arg_error.exception.message, + (r'^size_indices must be a vector with length equal to input rank, ' + r'but input rank is 3 and size_indices has shape \[2\].*')) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 0797b2cb17f5aae4080f339a201b44d69bbb2187..d549e7bb59905160a5599fea83667951a60e674d 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -76,6 +76,7 @@ cc_library( deps = [ ":common", ":dump_graph", + ":functionalize_control_flow", ":tf2xla_proto", ":tf2xla_util", ":xla_compiler", @@ -188,9 +189,9 @@ cc_library( deps = [ ":common", ":dump_graph", - ":functionalize_control_flow", ":host_compute_metadata_proto", ":sharding_util", + ":side_effect_util", ":tf2xla_util", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", @@ -283,6 +284,7 @@ cc_library( deps = [ ":sharding_util", ":tf2xla_proto", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -291,6 +293,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -358,6 +361,7 @@ tf_cc_test( name = 
"xla_compiler_test", srcs = ["xla_compiler_test.cc"], deps = [ + ":side_effect_util", ":xla_compiler", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", @@ -369,6 +373,7 @@ tf_cc_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:core_cpu_internal", @@ -433,6 +438,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -474,6 +480,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -501,11 +508,23 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], ) +cc_library( + name = "functionalize_control_flow_pass_registration", + srcs = [ + "functionalize_control_flow_pass_registration.cc", + ], + deps = [ + ":functionalize_control_flow", + ], + alwayslink = 1, +) + cc_library( name = "functionalize_while", srcs = [ @@ -515,6 +534,7 @@ cc_library( "functionalize_while.h", ], deps = [ + ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", "//tensorflow/compiler/jit:union_find", @@ -525,6 +545,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], @@ -539,6 +560,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", 
"//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/compiler/tf2xla/cc:xla_ops", @@ -609,11 +631,10 @@ cc_library( srcs = ["resource_operation_table.cc"], hdrs = ["resource_operation_table.h"], deps = [ - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:ops", - "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -630,3 +651,12 @@ tf_cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "side_effect_util", + srcs = ["side_effect_util.cc"], + hdrs = ["side_effect_util.h"], + deps = [ + "//tensorflow/core:core_cpu", + ], +) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index e8673d77903bd5a1a85412e9dfa86437f73d56bc..922ae7c79a1d3e0ad55bc2858a45cd6be1dc1117 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -26,8 +26,9 @@ namespace tensorflow { // Backwards dataflow analysis that finds arguments to a graph that must be // compile-time constants. Status BackwardsConstAnalysis(const Graph& g, - std::vector* compile_time_const_args, - std::vector* compile_time_const_nodes) { + std::vector* compile_time_const_arg_indices, + std::vector* compile_time_const_nodes, + std::function edge_filter) { // Operators that don't look at the data of their inputs, just the shapes. const std::unordered_set metadata_ops = { "Rank", @@ -45,8 +46,7 @@ Status BackwardsConstAnalysis(const Graph& g, } Status status; - auto visit = [&status, &metadata_ops, compile_time_const_nodes, - compile_time_const_args](Node* node) { + auto visit = [&](Node* node) { if (!status.ok()) return; // If this is a metadata-only op, don't propagate the const requirement. 
@@ -59,13 +59,13 @@ Status BackwardsConstAnalysis(const Graph& g, int index; status = GetNodeAttr(node->attrs(), "index", &index); if (!status.ok()) return; - if (compile_time_const_args) { - (*compile_time_const_args)[index] = true; + if (compile_time_const_arg_indices) { + (*compile_time_const_arg_indices)[index] = true; } return; } for (const Edge* pred : node->in_edges()) { - if (!pred->IsControlEdge()) { + if (!pred->IsControlEdge() && edge_filter(*pred)) { (*compile_time_const_nodes)[pred->src()->id()] = true; } } @@ -88,7 +88,8 @@ Status BackwardsConstAnalysis(const Graph& g, for (Edge const* edge : node->in_edges()) { if (edge->dst_input() >= name_range->second.first && - edge->dst_input() < name_range->second.second) { + edge->dst_input() < name_range->second.second && + edge_filter(*edge)) { (*compile_time_const_nodes)[edge->src()->id()] = true; } } @@ -97,7 +98,8 @@ Status BackwardsConstAnalysis(const Graph& g, // Post-order traversal visits nodes in reverse topological order for an // acyclic graph. - DFS(g, {}, visit); + DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{}, + [](const Edge& edge) { return !edge.src()->IsNextIteration(); }); return status; } diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h index af57e5a4033248e3fd32dabeda252c4ca0a44050..49b3c6d413c6b637fa825bf182be7cc36e49b6c8 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.h +++ b/tensorflow/compiler/tf2xla/const_analysis.h @@ -32,9 +32,13 @@ namespace tensorflow { // // The ids of the nodes in `graph` that must be constant are returned in // `compile_time_const_nodes`, if `compile_time_const_nodes` is not null. -Status BackwardsConstAnalysis(const Graph& graph, +// +// Only propagate const-ness along edges for which `edge_filter` returns true. 
+Status BackwardsConstAnalysis(const Graph& g, std::vector* compile_time_const_arg_indices, - std::vector* compile_time_const_nodes); + std::vector* compile_time_const_nodes, + std::function edge_filter = + [](const Edge& e) { return true; }); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 24616c01c7e54b2e8662457ca6af23a0bc563e08..380c6a7e23da92d949b26876836b999bf6406c6c 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" @@ -52,9 +52,9 @@ string MakeUniqueFilename(string name) { string filename = name; if (count > 0) { - strings::StrAppend(&filename, "_", count); + absl::StrAppend(&filename, "_", count); } - strings::StrAppend(&filename, ".pbtxt"); + absl::StrAppend(&filename, ".pbtxt"); return filename; } @@ -69,7 +69,7 @@ string WriteTextProtoToUniqueFile( << proto_type << ": " << status; return "(unavailable)"; } - string filepath = strings::StrCat(dirname, "/", MakeUniqueFilename(name)); + string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name)); status = WriteTextProto(Env::Default(), filepath, proto); if (!status.ok()) { LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index b5667ca0d3ba35bea9da2d702b5b49fb38fe6f02..db256e577a1f3dd38e04d102f60182023b9d43b2 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -34,30 +34,16 @@ limitations under the License. 
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" using xla::StatusOr; namespace tensorflow { namespace functionalize_cond { -string DebugString(const CondStateMap::CondNode& node) { - return node.ToString(); -} - // TODO(jpienaar): Move to OutputTensor. string DebugString(const OutputTensor& tensor) { - return strings::StrCat(tensor.node->name(), ":", tensor.index); -} - -string DebugString(CondStateMap::CondId cond_state) { - if (cond_state == nullptr || cond_state->empty()) return "[]"; - return strings::StrCat( - "[", - absl::StrJoin(*cond_state, ", ", - [](string* output, const CondStateMap::CondNode& node) { - strings::StrAppend(output, node.ToString()); - }), - "]"); + return absl::StrCat(tensor.node->name(), ":", tensor.index); } string Branch_Name(BranchType b) { @@ -73,6 +59,24 @@ string Branch_Name(BranchType b) { } } +string DebugString(StateMap::CondId cond_state) { + if (cond_state == nullptr || cond_state->empty()) return "{}"; + using value_type = StateMap::CondState::value_type; + return absl::StrCat( + "{", + absl::StrJoin(*cond_state, ", ", + [](string* output, const value_type& pred_branch) { + const OutputTensor& pred = pred_branch.first; + const BranchType& branch = pred_branch.second; + if (branch == BranchType::kNeither) + absl::StrAppend(output, "d"); + else + absl::StrAppend(output, "s(", DebugString(pred), ",", + Branch_Name(branch), ")"); + }), + "}"); +} + // Returns the predicate of a switch. 
Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { const Edge* pred_edge; @@ -86,64 +90,65 @@ Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { return Status::OK(); } -CondStateMap::CondNode::CondNode(Type type, Node* switch_node, - BranchType branch) - : type(type), branch(branch) { - if (type == Type::kSwitch) { - TF_CHECK_OK(GetSwitchPredicate(*switch_node, &predicate)); - } -} - -string CondStateMap::CondNode::ToString() const { - switch (type) { - case Type::kSwitch: - return strings::StrCat("s(", DebugString(predicate), ",", - Branch_Name(branch), ")"); - case Type::kMerge: - return "m"; - case Type::kDead: - return "d"; - } +Status GetSwitchValue(const Node& switch_node, OutputTensor* val) { + const Edge* val_edge; + TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge)); + *val = OutputTensor(val_edge->src(), val_edge->src_output()); + return Status::OK(); } -bool CondStateMap::CondNode::operator==(const CondNode& other) const { - if (type != Type::kSwitch) return type == other.type; - return type == other.type && predicate == other.predicate && - branch == other.branch; +bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs, + const OutputTensor& rhs) const { + return (lhs.node->id() < rhs.node->id()) || + (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index); } -bool CondStateMap::CondNode::operator!=(const CondNode& other) const { - return !(*this == other); -} +struct CondStateLess { + bool operator()(const StateMap::CondState::value_type& lhs, + const StateMap::CondState::value_type& rhs) const { + if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first)) + return true; + if (lhs.first.node->id() == rhs.first.node->id() && + lhs.first.index == rhs.first.index) + return lhs.second < rhs.second; + return false; + } +}; -CondStateMap::CondStateMap(Graph* graph) { +StateMap::StateMap(Graph* graph) { node_to_condid_map_.resize(graph->num_node_ids()); + 
node_to_ancestorid_map_.resize(graph->num_node_ids()); // Initialize the dead state (empty state is designated with a nullptr). - dead_id_ = GetUniqueId({CondNode(CondStateMap::CondNode::Type::kDead)}); + dead_id_ = GetCondId( + {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)}); } -bool CondStateMap::IsDead(CondStateMap::CondId id) const { - return id == dead_id_; -} +bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; } -bool CondStateMap::IsEmpty(CondStateMap::CondId id) const { - return id == nullptr; -} +bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; } -size_t CondStateMap::CondHash::operator()( - const CondStateMap::CondNode& item) const { - return Hash64Combine(Hash64Combine(OutputTensor::Hash()(item.predicate), - hash()(item.branch)), - hash()(item.type)); +size_t StateMap::Hash::operator()(const StateMap::CondState& map) const { + if (map.empty()) return 0; + // Compute hash of the front element. + auto it = map.begin(); + size_t h = Hash64Combine(OutputTensor::Hash()(it->first), + hash()(it->second)); + for (++it; it != map.end(); ++it) { + // Combine the has with the different elements in the map. + h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first), + hash()(it->second))); + } + return h; } -size_t CondStateMap::CondHash::operator()( - const CondStateMap::CondState& vec) const { - if (vec.empty()) return 0; - size_t h = (*this)(vec.front()); - auto it = vec.begin(); - for (++it; it != vec.end(); ++it) { - h = Hash64Combine(h, (*this)(*it)); +size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const { + if (map.empty()) return 0; + // Compute hash of the front element. + auto it = map.begin(); + size_t h = hash()(*it); + for (++it; it != map.end(); ++it) { + // Combine the has with the different elements in the map. 
+ h = Hash64Combine(h, hash()(*it)); } return h; } @@ -155,8 +160,8 @@ struct CondArgNode { : src(src), src_output(src_output) {} string ToString() const { - return strings::StrCat("src=", src->name(), ":", src_output, - " switches=", NodesToString(switches)); + return absl::StrCat("src=", src->name(), ":", src_output, + " switches=", NodesToString(switches)); } Node* src; @@ -167,58 +172,76 @@ struct CondArgNode { using CondArgNodes = std::vector; string DebugString(const CondArgNodes& nodes) { - return strings::StrCat( + return absl::StrCat( "[", absl::StrJoin(nodes, ", ", [](string* output, const CondArgNode& node) { - strings::StrAppend(output, node.ToString()); + absl::StrAppend(output, node.ToString()); }), "]"); } -CondStateMap::CondId CondStateMap::LookupId(const Node* node) const { +StateMap::CondId StateMap::LookupCondId(const Node* node) const { if (node->id() < node_to_condid_map_.size()) return node_to_condid_map_[node->id()]; - return added_node_mapping_.at(node->id()); + return added_node_condid_mapping_.at(node->id()); } -CondStateMap::CondId CondStateMap::GetUniqueId( - const CondStateMap::CondState& state) { +StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) { if (state.empty()) return nullptr; return &*condstate_set_.insert(state).first; } -const CondStateMap::CondState& CondStateMap::LookupState( - const Node* node) const { - return *LookupId(node); -} - -void CondStateMap::ResetId(const Node* node, CondStateMap::CondId id) { +void StateMap::ResetCondId(const Node* node, StateMap::CondId id) { if (node->id() < node_to_condid_map_.size()) node_to_condid_map_[node->id()] = id; else - added_node_mapping_[node->id()] = id; + added_node_condid_mapping_[node->id()] = id; +} + +StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const { + if (node->id() < node_to_ancestorid_map_.size()) + return node_to_ancestorid_map_[node->id()]; + return added_node_ancestorid_mapping_.at(node->id()); +} + +StateMap::AncestorId 
StateMap::GetAncestorId( + const StateMap::AncestorState& state) { + if (state.empty()) return nullptr; + return &*ancestorstate_set_.insert(state).first; +} + +void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) { + if (node->id() < node_to_ancestorid_map_.size()) + node_to_ancestorid_map_[node->id()] = id; + else + added_node_ancestorid_mapping_[node->id()] = id; } -void CondStateMap::MarkDead(const Node* node) { ResetId(node, dead_id_); } +void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); } -string CondStateMap::CondStateToString(const Node* node) const { - return CondStateToString(LookupId(node)); +string StateMap::CondStateToString(const Node* node) const { + return CondStateToString(LookupCondId(node)); } -string CondStateMap::CondStateToString(CondStateMap::CondId id) const { +string StateMap::CondStateToString(StateMap::CondId id) const { return DebugString(id); } +string StateMap::AncestorStateToString(const Node* node) const { + if (auto id = LookupAncestorId(node)) return NodesToString(*id); + return "{}"; +} + FunctionalizeCond::FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) - : cond_state_map_(graph), library_(library), graph_(graph) {} + : state_map_(graph), library_(library), graph_(graph) {} // Class representing the merge/switch nodes that will become a conditional. class Conditional { public: Conditional(OutputTensor predicate, FunctionalizeCond* parent, - CondStateMap* cond_state_map); + StateMap* cond_state_map); // Adds merge node that is part of this conditional. Status AddMerge(Node* m); @@ -247,6 +270,10 @@ class Conditional { // Adds switch node that is part of this conditional. Status AddSwitch(Node* s); + // Adds a switch node along the edge and rewire the edge to go via the switch. + Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, + Graph* graph); + // Internal name of conditional. The name is based on the first merge node // added. 
string name() const; @@ -255,7 +282,7 @@ class Conditional { FunctionalizeCond* parent_; // Mapping between nodes and their cond state. - CondStateMap* cond_state_map_; + StateMap* state_map_; // The predicate of the conditional. OutputTensor predicate_; @@ -292,8 +319,8 @@ class Conditional { }; Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent, - CondStateMap* cond_state_map) - : parent_(parent), cond_state_map_(cond_state_map), predicate_(predicate) {} + StateMap* cond_state_map) + : parent_(parent), state_map_(cond_state_map), predicate_(predicate) {} Status Conditional::AddMerge(Node* m) { merges_.insert(m); @@ -343,7 +370,7 @@ Status Conditional::BuildArgumentNodes() { for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_Arg", arg_count), + NodeBuilder(absl::StrCat("_Arg", arg_count), FunctionLibraryDefinition::kArgOp) .Attr("T", dtype) .Attr("index", arg_count) @@ -397,6 +424,35 @@ Status Conditional::BuildArgumentNodes() { return Status::OK(); } +Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, + Graph* graph) { + // Previously we had edge: + // src:src_output ---- edge ----> dst:dst_input + // post this we have (in graph) + // src:src_output --> switch --- new_edge --> dst:dst_input + + // TODO(jpienaar): One could keep a map caching the extra switch nodes added + // to avoid adding another switch to feed a value for which a switch was + // already added. 
+ Node* switch_node; + Node* src = edge->src(); + int src_output = edge->src_output(); + TF_RETURN_IF_ERROR( + NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")), + "Switch") + .Input(src, src_output) + .Input(const_cast(predicate_.node), predicate_.index) + .Finalize(graph, &switch_node)); + state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src)); + state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src)); + + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + graph->AddEdge(switch_node, static_cast(branch), dst, dst_input); + return AddSwitch(switch_node); +} + Status Conditional::ExtractBodies(Graph* graph) { VLOG(2) << "Extracting bodies for " << name(); for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) { @@ -405,16 +461,16 @@ Status Conditional::ExtractBodies(Graph* graph) { } auto find_branch = [&](const Edge* e) { - const auto& id = cond_state_map_->LookupId(e->src()); + const auto& id = state_map_->LookupCondId(e->src()); return IsSwitch(e->src()) ? BranchType(e->src_output()) - : cond_state_map_->FindBranchOf(id, predicate_); + : state_map_->FindBranchOf(id, predicate_); }; std::array, 2> stacks; VLOG(5) << "Merges: " << NodesToString(merges_); for (Node* m : merges_) { VLOG(5) << "For merge: " << m->DebugString() << " " - << cond_state_map_->CondStateToString(m); + << state_map_->CondStateToString(m); for (auto e : m->in_edges()) { if (e->IsControlEdge()) continue; BranchType branch = find_branch(e); @@ -422,7 +478,8 @@ Status Conditional::ExtractBodies(Graph* graph) { branch == BranchType::kElseBranch) << "Error: " << e->src()->name() << " is not on either then or else branch (" << Branch_Name(branch) - << ")."; + << ") for predicate " << DebugString(predicate_) << " [" + << DebugString(state_map_->LookupCondId(e->src())) << "]."; Node* src = e->src(); if (IsSwitch(src)) { // Switch node outputs and dependencies are handled separately. 
@@ -456,8 +513,8 @@ Status Conditional::ExtractBodies(Graph* graph) { if (IsMerge(dst)) continue; Node* src = e->src(); - auto dst_id = cond_state_map_->LookupId(dst); - auto src_id = cond_state_map_->LookupId(src); + auto dst_id = state_map_->LookupCondId(dst); + auto src_id = state_map_->LookupCondId(src); if (dst_id != src_id) { if (e->IsControlEdge()) { external_control_outputs_.push_back(e->src()); @@ -480,8 +537,11 @@ Status Conditional::ExtractBodies(Graph* graph) { } } - // Copying incomming edges to dst node. - for (const Edge* e : n->in_edges()) { + // Copying incomming edges to dst node. Iterate over a copy of the edges + // as they could be mutated during iteration. + std::vector in_edges(n->in_edges().begin(), + n->in_edges().end()); + for (const Edge* e : in_edges) { Node* src = e->src(); // Skip src/dst node. if (!src->IsOp()) continue; @@ -494,8 +554,8 @@ Status Conditional::ExtractBodies(Graph* graph) { } // Verify input is from the same context. - auto src_id = cond_state_map_->LookupId(src); - auto dst_id = cond_state_map_->LookupId(dst); + auto src_id = state_map_->LookupCondId(src); + auto dst_id = state_map_->LookupCondId(dst); if (IsMerge(dst) || src_id == dst_id) { // TODO(jpienaar): The merge case can be more strict. if (node_map.at(src->id()) == nullptr) { @@ -506,18 +566,25 @@ Status Conditional::ExtractBodies(Graph* graph) { external_control_inputs_.push_back(src); } else { // This shouldn't happen, this means we have an external data input - // not entering via a switch node. Work around this for constant - // nodes as some constant nodes are inserted without the required - // control context dominance. + // not entering via a switch node. 
Work around this by for + // * constant nodes copy them; + // * non-constant nodes, insert a switch along the edge; if (IsConstant(src)) { node_map.at(src->id()) = output->CopyNode(src); } else { - return errors::InvalidArgument( - "Graph contains node ", FormatNodeForError(*src), - " that feeds into node ", FormatNodeForError(*dst), - " but these nodes are in different control contexts (", - DebugString(src_id), " vs ", DebugString(dst_id), - " (detected during in edge testing)"); + StateMap::CondState state = *dst_id; + state.erase(predicate_); + if (state_map_->GetCondId(state) == src_id) { + TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph)); + continue; + } else { + return errors::InvalidArgument( + "Graph contains node ", FormatNodeForError(*src), + " that feeds into node ", FormatNodeForError(*dst), + " but these nodes are in different control contexts (", + DebugString(src_id), " vs ", DebugString(dst_id), + " (detected during in edge testing)"); + } } } @@ -572,7 +639,7 @@ Status Conditional::ExtractBodies(Graph* graph) { Status Conditional::BuildIfNode(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); - NodeDefBuilder builder(name(), "If"); + NodeDefBuilder builder(name(), "If", library); const string branch_name[] = {"else_branch", "then_branch"}; for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); @@ -580,8 +647,8 @@ Status Conditional::BuildIfNode(Graph* graph, int64 id = ++sequence_num; NameAttrList body_name; - body_name.set_name(strings::StrCat("_functionalize_if_", - branch_name[branch_index], "_", id)); + body_name.set_name( + absl::StrCat("_functionalize_if_", branch_name[branch_index], "_", id)); VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index] << "): " @@ -639,7 +706,8 @@ Status Conditional::BuildIfNode(Graph* graph, VLOG(3) << "Build If node"; NodeDef if_def; 
TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); - TF_ASSIGN_OR_RETURN(if_node_, parent_->AddIfNode(if_def, *merges_.begin())); + TF_ASSIGN_OR_RETURN(if_node_, + parent_->AddIfNode(if_def, *merges_.begin(), predicate_)); return Status::OK(); } @@ -699,7 +767,8 @@ Status Conditional::AddOutputEdges(Graph* graph) { Status Conditional::BuildAndReplace(Graph* graph, FunctionLibraryDefinition* library) { - VLOG(1) << "Build If and replace merge nodes " << name(); + VLOG(1) << "Build If and replace merge nodes " + << NodesToString(this->merges_); if (replaced_) return Status::OK(); TF_RETURN_IF_ERROR(ExtractBodies(graph)); @@ -719,7 +788,6 @@ Status Conditional::BuildAndReplace(Graph* graph, TF_RETURN_IF_ERROR(AddInputEdges(graph)); TF_RETURN_IF_ERROR(AddOutputEdges(graph)); TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_)); - for (Node* m : merges_) cond_state_map_->MarkDead(m); // Check that the if_node doesn't feed into itself. TF_RETURN_WITH_CONTEXT_IF_ERROR( @@ -732,31 +800,7 @@ Status Conditional::BuildAndReplace(Graph* graph, string Conditional::name() const { CHECK(!merges_.empty()); - return strings::StrCat((*merges_.begin())->name(), "_if"); -} - -bool CondStateMap::ScopeIn(CondStateMap::CondId id, - CondStateMap::CondId* scope) { - if (id == nullptr) { - *scope = nullptr; - return true; - } - CondState state; - for (const CondNode& node : *id) { - if (node.type == CondNode::Type::kSwitch) { - state.push_back(node); - } - if (node.type == CondNode::Type::kMerge) { - if (state.empty()) { - return false; - } - DCHECK(state.back().type == CondNode::Type::kSwitch && - state.back().branch == BranchType::kBoth); - state.pop_back(); - } - } - *scope = GetUniqueId(state); - return true; + return absl::StrCat((*merges_.begin())->name(), "_if"); } Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, @@ -765,25 +809,35 @@ Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, 
TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity") .Input(if_node, port) .Finalize(graph_, &id)); - cond_state_map_.ResetId(id, cond_state_map_.LookupId(if_node)); + state_map_.ResetCondId(id, state_map_.LookupCondId(if_node)); + state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node)); return Status::OK(); } StatusOr FunctionalizeCond::AddIfNode(const NodeDef& def, - const Node* replacee) { + const Node* replacee, + const OutputTensor& predicate) { Status status; Node* ret = graph_->AddNode(def, &status); TF_RETURN_IF_ERROR(status); - CondStateMap::CondState state = cond_state_map_.LookupState(replacee); - state.pop_back(); VLOG(1) << "Adding If for " << replacee->name(); - cond_state_map_.ResetId(ret, cond_state_map_.GetUniqueId(state)); + StateMap::CondId id = state_map_.LookupCondId(replacee); + if (id) { + StateMap::CondState state = *id; + state.erase(predicate); + state_map_.ResetCondId(ret, state_map_.GetCondId(state)); + } else { + state_map_.ResetCondId(ret, nullptr); + } + + state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee)); + return ret; } Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { VLOG(2) << "Propagating update state for " << replacee->name() << " " - << cond_state_map_.CondStateToString(replacee); + << state_map_.CondStateToString(replacee); // Redo topological sort as the order could have changed. // TODO(jpienaar): The original topological order could also be updated // dynamically if needed. @@ -801,10 +855,10 @@ Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { if (changed.find(*it) != changed.end()) { // Update the node state. 
Node* n = *it; - CondStateMap::CondId old_state = cond_state_map_.LookupId(n); - cond_state_map_.ResetId(n, nullptr); + StateMap::CondId old_state = state_map_.LookupCondId(n); + state_map_.ResetCondId(n, nullptr); TF_RETURN_IF_ERROR(DetermineCondState(n)); - if (cond_state_map_.LookupId(n) != old_state) { + if (state_map_.LookupCondId(n) != old_state) { for (auto out : n->out_nodes()) if (out->IsOp()) changed.insert(out); } @@ -825,127 +879,44 @@ BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) { return BranchType::kNeither; } -CondStateMap::ContainsResult CondStateMap::LhsHoldsWhereverRhsHolds( - CondStateMap::CondId lhs, CondStateMap::CondId rhs) { - CondId lhs_scope; - CondId rhs_scope; - bool could_determine_scope = ScopeIn(lhs, &lhs_scope); - could_determine_scope = could_determine_scope && ScopeIn(rhs, &rhs_scope); - if (!could_determine_scope) return kIncomparable; - - // Returns whether a contains b. - auto contains = [&](CondId a, CondId b) { - // Handle empty states. 
- if (a == nullptr && b != nullptr) return true; - if (a == nullptr && b == nullptr) return true; - if (a != nullptr && b == nullptr) return false; - - if (a->size() > b->size()) return false; - auto a_it = a->begin(); - auto b_it = b->begin(); - while (a_it != a->end()) { - if (*a_it != *b_it) { - if (!(a_it->predicate == b_it->predicate)) return false; - BranchType mb = MeetBranch(a_it->branch, b_it->branch); - if (mb != b_it->branch) return false; - } - ++a_it; - ++b_it; - } - return true; - }; - - bool lhs_contains_rhs = contains(lhs_scope, rhs_scope); - bool rhs_contains_lhs = contains(rhs_scope, lhs_scope); - if (lhs_contains_rhs && rhs_contains_lhs) return kEqual; - if (lhs_contains_rhs) return kLhsContainsRhs; - if (rhs_contains_lhs) return kRhsContainsLhs; - return kIncomparable; -} - -BranchType CondStateMap::FindBranchOf(CondId id, OutputTensor predicate) const { +BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const { if (IsEmpty(id)) return BranchType::kNeither; - absl::optional b; const CondState& nodes = *id; - for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) { - if (it->type == CondStateMap::CondNode::Type::kSwitch && - it->predicate == predicate) { - if (b.has_value()) { - b = MeetBranch(*b, it->branch); - } else { - b = it->branch; - } - if (*b == BranchType::kNeither) { - LOG(FATAL) << "Inconsistent state for node: " << DebugString(id); - } - } - } - return b.has_value() ? 
*b : BranchType::kNeither; + auto it = nodes.find(predicate); + if (it == nodes.end()) return BranchType::kNeither; + return it->second; } -StatusOr FunctionalizeCond::JoinCondStatesNonMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { - VLOG(4) << "Joining src=" << DebugString(src) << " [" << src +StatusOr FunctionalizeCond::JoinCondStatesNonMerge( + StateMap::CondId src, StateMap::CondId dst) { + VLOG(5) << "Joining src=" << DebugString(src) << " [" << src << "] and dst=" << DebugString(dst) << " [" << dst << "]"; - if (cond_state_map_.IsEmpty(dst) || cond_state_map_.IsDead(src)) return src; - if (cond_state_map_.IsDead(dst)) return dst; + if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src; + if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst; // Nothing to do if the CondState is the same. if (src == dst) return src; - CondStateMap::CondId src_scope; - CondStateMap::CondId dst_scope; - if (!cond_state_map_.ScopeIn(src, &src_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(src)); - if (!cond_state_map_.ScopeIn(dst, &dst_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(dst)); - - auto result = cond_state_map_.LhsHoldsWhereverRhsHolds(src_scope, dst_scope); - switch (result) { - case CondStateMap::kIncomparable: - return errors::InvalidArgument( - "Graph contains node with inputs predicated on incompatible " - "predicates: ", - DebugString(src), " and ", DebugString(dst)); - case CondStateMap::kEqual: - // If both respect the same predicates, propagate the longer constraint. - if ((src != nullptr && dst == nullptr) || - (src != nullptr && dst != nullptr && src->size() > dst->size())) - return src; - else - return dst; - case CondStateMap::kLhsContainsRhs: - // src contains dst, so dst is already more restrictive. 
- return dst; - case CondStateMap::kRhsContainsLhs: - // dst contains src, so src is more restrictive. - return src; - } -} - -StatusOr -FindThenElseSwitchForPredicate(const OutputTensor& pred, - CondStateMap::CondId id) { - for (auto it = id->begin(); it != id->end(); ++it) { - // Along every path one there can be only one instance of a then or else - // switch for a given predicate, so return once found. - if (it->type == CondStateMap::CondNode::Type::kSwitch && - it->predicate == pred && - (it->branch == BranchType::kThenBranch || - it->branch == BranchType::kElseBranch)) - return it; + StateMap::CondState both = *src; + for (const auto& kv : *dst) { + auto it = both.find(kv.first); + if (it == both.end()) { + both.insert(kv); + } else { + if (it->second != kv.second) { + return errors::InvalidArgument( + "Graph contains node with inputs predicated on incompatible " + "predicates: ", + DebugString(src), " and ", DebugString(dst)); + } + } } - return errors::Internal("Unable to find then/else branch with predicate ", - DebugString(pred), " for ", DebugString(id)); + return state_map_.GetCondId(both); } -StatusOr FunctionalizeCond::JoinCondStatesMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { +StatusOr FunctionalizeCond::JoinCondStatesMerge( + Node* merge, StateMap::CondId src, StateMap::CondId dst) { // Determine the flow state when joining two states for a merge // node. Combining the two states for a merge node is effectively performing a // disjunction of the states along the different input edges. For a merge that @@ -956,91 +927,56 @@ StatusOr FunctionalizeCond::JoinCondStatesMerge( // followed by s(p, both). 
VLOG(4) << "Joining (for merge) " << DebugString(src) << " and " << DebugString(dst); - if (cond_state_map_.IsEmpty(dst)) return src; - - if (cond_state_map_.IsDead(src)) return src; - if (cond_state_map_.IsDead(dst)) return dst; - - CondStateMap::CondId src_scope; - CondStateMap::CondId dst_scope; - if (!cond_state_map_.ScopeIn(src, &src_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(src)); - if (!cond_state_map_.ScopeIn(dst, &dst_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(dst)); - - TF_RET_CHECK(src_scope != nullptr && dst_scope != nullptr) - << "Illegal merge inputs from outer scope: src=" << DebugString(src) - << " dst=" << DebugString(dst); - auto src_it = src_scope->begin(); - auto dst_it = dst_scope->begin(); - - // Find branch divergent condition. - OutputTensor pred; - while (src_it != src_scope->end() && dst_it != dst_scope->end()) { - if (*src_it != *dst_it) { - VLOG(5) << "Diverges with: " << DebugString(*src_it) << " and " - << DebugString(*dst_it); - if (!(src_it->predicate == dst_it->predicate)) { - return errors::InvalidArgument( - "Unable to find common predicate which holds for one input " - "but not the other of the merge node."); - } - pred = src_it->predicate; - break; - } - ++src_it; - ++dst_it; - } - - if (pred.node == nullptr) - return errors::InvalidArgument("Unable to determine predicate for merge."); - - TF_ASSIGN_OR_RETURN(auto div_src_it, - FindThenElseSwitchForPredicate(pred, src)); - TF_ASSIGN_OR_RETURN(auto div_dst_it, - FindThenElseSwitchForPredicate(pred, dst)); - TF_RET_CHECK(*div_src_it != *div_dst_it); - - CondStateMap::CondState result; - // Populate result with the longest/most restrictive path up to the divergent - // node. 
For example, if the one input is `[switch(pred:0, then)]` and the - // other is `[switch(pred:0, both), merge, switch(pred:0, else)]` (as created - // in gradient of cond test), then the resultant state here should be - // `[switch(pred:0, both), merge, switch(pred:0, both)]`. - if (std::distance(src->begin(), div_src_it) > - std::distance(dst->begin(), div_dst_it)) { - result.assign(src->begin(), std::next(div_src_it)); + if (state_map_.IsEmpty(dst)) return src; + + if (state_map_.IsDead(src)) return src; + if (state_map_.IsDead(dst)) return dst; + + std::vector diff; + StateMap::CondState merged; + std::set_symmetric_difference(src->begin(), src->end(), dst->begin(), + dst->end(), std::back_inserter(diff), + CondStateLess()); + std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(), + std::inserter(merged, merged.begin()), CondStateLess()); + + // Update mapping from merge node to predicate. + if (diff.size() == 2) { + auto pred = diff[0].first; + bool different_branches = (diff[0].second != diff[1].second) && + (diff[0].second == BranchType::kThenBranch || + diff[0].second == BranchType::kElseBranch) && + (diff[1].second == BranchType::kThenBranch || + diff[1].second == BranchType::kElseBranch); + if (!(pred == diff[1].first) || !different_branches) + return errors::InvalidArgument( + "Unable to determine predicate for merge node"); + merge_to_predicate_[merge] = pred; } else { - result.assign(dst->begin(), std::next(div_dst_it)); + return errors::InvalidArgument( + "Merge of two inputs that differ on more than one predicate ", + DebugString(src), " and ", DebugString(dst)); } - result.back().branch = BranchType::kBoth; - return cond_state_map_.GetUniqueId(result); + + return state_map_.GetCondId(merged); } -CondStateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { +StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { Node* src = e->src(); - CondStateMap::CondId id = cond_state_map_.LookupId(e->src()); - if 
(IsMerge(src)) { - CondStateMap::CondState state; - if (id != nullptr) state = *id; - state.emplace_back(CondStateMap::CondNode::Type::kMerge); - return cond_state_map_.GetUniqueId(state); - } + StateMap::CondId id = state_map_.LookupCondId(e->src()); + + // Dead nodes only propagate dead state. + if (state_map_.IsDead(id)) return id; + if (IsSwitch(src)) { - CondStateMap::CondState state; + StateMap::CondState state; if (id != nullptr) state = *id; - if (e->IsControlEdge()) { - state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src, - BranchType::kBoth); - } else { - state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src, - BranchType(e->src_output())); + OutputTensor predicate; + TF_CHECK_OK(GetSwitchPredicate(*src, &predicate)); + if (!e->IsControlEdge()) { + state[predicate] = BranchType(e->src_output()); } - return cond_state_map_.GetUniqueId(state); + return state_map_.GetCondId(state); } return id; } @@ -1049,22 +985,21 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { // Only Merge nodes with two inputs are supported, but if this is a redundant // merge, then the dead edge may already have been removed (if due to a // switch) and so the input count would be incorrect. 
- if (cond_state_map_.IsDead(cond_state_map_.LookupId(dst))) - return Status::OK(); + if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK(); int data_inputs = 0; for (auto e : dst->in_edges()) { Node* src = e->src(); VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " " - << cond_state_map_.CondStateToString(src); + << state_map_.CondStateToString(src); if (!src->IsOp()) continue; if (!e->IsControlEdge()) ++data_inputs; - CondStateMap::CondId prop = StateAlongEdge(e); - auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst)); + StateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst)); TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", FormatNodeForError(*dst)); - cond_state_map_.ResetId(dst, id_or.ValueOrDie()); + state_map_.ResetCondId(dst, id_or.ValueOrDie()); } // Incomplete Merge nodes are not supported. @@ -1076,27 +1011,20 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { return Status::OK(); } -Status FunctionalizeCond::DetermineCondState(Node* dst) { - // The logic for the merge and non-merge case differ: for non-merge it is - // the most restrictive CondState, while for merge nodes the - // resultant state is less restrictive than either. - if (IsMerge(dst)) { - TF_RETURN_IF_ERROR(DetermineCondStateMerge(dst)); - } else { - // Handle non-merge join. - for (auto e : dst->in_edges()) { - VLOG(5) << "Processing forward flow for: " << e->DebugString() << " " - << cond_state_map_.CondStateToString(dst); - Node* src = e->src(); - if (!src->IsOp()) continue; - - // Joining the state between the current and propagated state. 
- CondStateMap::CondId prop = StateAlongEdge(e); - auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst)); - TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", - FormatNodeForError(*dst)); - cond_state_map_.ResetId(dst, id_or.ValueOrDie()); - } +Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) { + // Handle non-merge join. + for (auto e : dst->in_edges()) { + VLOG(4) << "Processing forward flow for: " << e->DebugString() << " " + << state_map_.CondStateToString(dst); + Node* src = e->src(); + if (!src->IsOp()) continue; + + // Joining the state between the current and propagated state. + StateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst)); + state_map_.ResetCondId(dst, id_or.ValueOrDie()); } return Status::OK(); } @@ -1104,8 +1032,7 @@ Status FunctionalizeCond::DetermineCondState(Node* dst) { Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { // Handle redundant merge nodes. A merge node is considered redundant if // one input edge is dead while the other has a value. - if (!cond_state_map_.IsDead(cond_state_map_.LookupId(node))) - return Status::OK(); + if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK(); const Edge* non_dead_edge = nullptr; for (auto e : node->in_edges()) { @@ -1113,8 +1040,8 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { Node* src = e->src(); // Handle merge with dead state. 
- const auto& src_id = cond_state_map_.LookupId(src); - if (!cond_state_map_.IsDead(src_id)) { + const auto& src_id = state_map_.LookupCondId(src); + if (!state_map_.IsDead(src_id)) { non_dead_edge = e; break; } @@ -1124,8 +1051,7 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { return errors::InvalidArgument("Merge node ", FormatNodeForError(*node), " has no non-dead inputs."); } - cond_state_map_.MarkDead(node); - delete_nodes_.push_back(node->id()); + state_map_.MarkDead(node); VLOG(5) << "removing redundant merge: " << node->name(); while (!node->out_edges().empty()) { const Edge* oe = *node->out_edges().begin(); @@ -1149,16 +1075,33 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { // along one. The checking of predicate is based on the exact predicate // (rather than boolean equivalence) and aimed at redundant switches as // currently generated by gradient code. + StateMap::CondId dst_id = state_map_.LookupCondId(node); + if (state_map_.IsDead(dst_id)) return Status::OK(); + + BranchType b; OutputTensor pred; TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred)); - auto dst_id = cond_state_map_.LookupId(node); - BranchType b = cond_state_map_.FindBranchOf(dst_id, pred); + // Determine if we are already on a branch where the switch predicate is - // true/false. - if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) - return Status::OK(); + // true/false. Consider both the data and predicate to determine if the + // node is redundant (skipping over identity node). 
+ b = state_map_.FindBranchOf(dst_id, pred); + if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) { + OutputTensor val; + const Edge* e; + TF_RETURN_IF_ERROR(node->input_edge(0, &e)); + val = OutputTensor(e->src(), e->src_output()); + while (IsIdentity(val.node)) { + TF_RETURN_IF_ERROR(val.node->input_edge(0, &e)); + val = OutputTensor(e->src(), e->src_output()); + } + b = state_map_.FindBranchOf(dst_id, val); + if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) + return Status::OK(); + } - VLOG(5) << "Redundant switch " << node->name(); + VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " " + << DebugString(dst_id); const Edge* value_edge; TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge)); Node* val_node = value_edge->src(); @@ -1171,20 +1114,19 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { graph_->RemoveEdge(e); if (switch_branch == Graph::kControlSlot) { if (IsMerge(dst_node)) { - auto id_or = - JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node)); + auto id_or = JoinCondStatesMerge(dst_node, dst_id, + state_map_.LookupCondId(dst_node)); TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", FormatNodeForError(*dst_node)); - cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); + state_map_.ResetCondId(dst_node, id_or.ValueOrDie()); } else { auto id_or = - JoinCondStatesNonMerge(dst_id, cond_state_map_.LookupId(dst_node)); + JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node)); TF_RETURN_IF_ERROR(id_or.status()); - cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); + state_map_.ResetCondId(dst_node, id_or.ValueOrDie()); } } else if (BranchType(switch_branch) != b) { - cond_state_map_.MarkDead(dst_node); - delete_nodes_.push_back(dst_node->id()); + state_map_.MarkDead(dst_node); continue; } graph_->AddEdge( @@ -1195,37 +1137,103 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { return Status::OK(); } -Status 
FunctionalizeCond::DetermineCondStates( - std::vector rev_topo_order) { +Status FunctionalizeCond::DetermineStates(std::vector rev_topo_order) { // The state that is propagated along the given edge. for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) { Node* dst = *it; TF_RETURN_IF_ERROR(DetermineCondState(dst)); + TF_RETURN_IF_ERROR(DetermineAncestorState(dst)); if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst)); if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst)); - VLOG(5) << dst->name() << " :: " << cond_state_map_.CondStateToString(dst); + VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst) + << " @ " << state_map_.AncestorStateToString(dst); + if (VLOG_IS_ON(10)) DumpGraphWithCondState("it"); } return Status::OK(); } -void FunctionalizeCond::DeleteReachableNodes() { +Status FunctionalizeCond::DetermineAncestorState(Node* dst) { + StateMap::AncestorId id = nullptr; + StateMap::AncestorState state; + + auto insert = [&](StateMap::AncestorId id, Node* src) { + auto other_id = state_map_.LookupAncestorId(src); + if (other_id != id && other_id != nullptr) { + state.insert(other_id->begin(), other_id->end()); + } + if (IsSwitch(src) || IsMerge(src)) { + state.insert(src); + } + return state_map_.GetAncestorId(state); + }; + + // Compute the union of all the switch/merge nodes that affects the input of + // dst. + for (auto e : dst->in_edges()) { + Node* src = e->src(); + id = insert(id, src); + } + state_map_.ResetAncestorId(dst, id); + return Status::OK(); +} + +void FunctionalizeCond::DeleteReachableAndDeadNodes( + const std::vector& switch_ids, const std::vector& merge_order) { // Delete all nodes that have been extracted or are reachable from // deleted/dead nodes. The input and outgoing edges should have already been // removed. + std::deque delete_nodes; std::vector deleted(graph_->num_node_ids(), false); // Don't try to delete source or sink nodes. 
deleted[graph_->kSourceId] = true; deleted[graph_->kSinkId] = true; - while (!delete_nodes_.empty()) { - int d_id = delete_nodes_.front(); - delete_nodes_.pop_front(); + + // All remaining Switch nodes are not reachable from a Merge node and + // removed. This is to account for dead Switch nodes. + for (int s_id : switch_ids) { + Node* s = graph_->FindNodeId(s_id); + if (s == nullptr) continue; + for (const Edge* e : s->out_edges()) { + // Control outputs of switch nodes (which are unconditionally executed if + // the switch is) are not removed as they need not be part of a + // conditional. + if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id()); + } + deleted[s_id] = true; + graph_->RemoveNode(s); + } + + // All merge nodes should have been transformed at this point and we remove + // them from the graph here. + for (Node* m : merge_order) { + for (const Edge* e : m->out_edges()) { + // Similar to control outputs of switch nodes don't remove control + // outputs of merge nodes. + // TODO(jpienaar): Check cases where output edges still exist here vs + // being removed in AddOutputEdges. + if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id()); + } + deleted[m->id()] = true; + graph_->RemoveNode(m); + } + + // Enqueue all the dead nodes. + for (Node* n : graph_->nodes()) { + if (state_map_.IsDead(state_map_.LookupCondId(n))) { + delete_nodes.push_back(n->id()); + } + } + + while (!delete_nodes.empty()) { + int d_id = delete_nodes.front(); + delete_nodes.pop_front(); if (deleted[d_id]) continue; Node* d = graph_->FindNodeId(d_id); // Switch and Merge nodes could have been deleted already. 
if (d == nullptr) continue; for (const Edge* e : d->out_edges()) { - delete_nodes_.push_back(e->dst()->id()); + delete_nodes.push_back(e->dst()->id()); } deleted[d_id] = true; graph_->RemoveNode(d); @@ -1239,16 +1247,8 @@ void FunctionalizeCond::SortMergeNodes(std::vector* merge_order) { inner_to_outer_merge_order.reserve(merge_order->size()); for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) { Node* merge = *it; - CondStateMap::CondId id = cond_state_map_.LookupId(merge); - int depth = 0; - for (auto cond_node_it = id->begin(); cond_node_it != id->end(); - ++cond_node_it) { - if (cond_node_it->type == CondStateMap::CondNode::Type::kSwitch && - (cond_node_it->branch == BranchType::kThenBranch || - cond_node_it->branch == BranchType::kElseBranch)) { - ++depth; - } - } + StateMap::CondId id = state_map_.LookupCondId(merge); + int depth = id != nullptr ? id->size() : 0; inner_to_outer_merge_order.emplace_back(depth, merge); } std::stable_sort( @@ -1271,10 +1271,10 @@ Status FunctionalizeCond::FunctionalizeInternal() { // determine deeper equivalence). We shall refer to this structure as the // CondState; // 3. Sort the merge nodes by nesting depth; - // 4. Extract merge nodes together that have the same CondState and whose - // input nodes have the same state from the innermost to the outermost into - // IfOps; Note: In the above only nodes paths that converge to a merge node - // will be considered for removal. + // 4. Extract merge nodes together that have the same CondState and + // AncestorState from the innermost to the outermost into IfOps; + // Note: In the above only nodes that feed into a merge node will be + // considered for functionalization. 
// Perform a DFS over the graph and // * Determine the reverse topological order of the nodes (there should be no @@ -1306,50 +1306,46 @@ Status FunctionalizeCond::FunctionalizeInternal() { return Status::OK(); } - TF_RETURN_IF_ERROR(DetermineCondStates(std::move(rev_topo_order))); - - if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id"); + TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order))); + if (VLOG_IS_ON(4)) DumpGraphWithCondState("id"); // Sort the merge nodes from innermost outwards. SortMergeNodes(&merge_order); - // Extract from innermost out. - for (auto it = merge_order.begin(); it != merge_order.end(); ++it) { - Node* merge = *it; - auto id = cond_state_map_.LookupId(merge); - if (cond_state_map_.IsDead(id)) continue; - - // Construct a Conditional with the predicate of the merge (which is the - // last entry of the CondState for the merge) and this as parent. - DCHECK(id->back().predicate.node != nullptr); - Conditional cond(id->back().predicate, this, &cond_state_map_); - TF_RETURN_IF_ERROR(cond.AddMerge(merge)); - - // Find all merge nodes with the same CondId. This is done repeatedly as - // the CondId can change due replaced conditionals. E.g., the one branch - // could previously have had a conditional nested in it, and so would have - // had CondState with sub-state [switch(p,b),m] (where p is some predicate), - // post removing the nested conditional that sub-state would no longer be - // path of the propagated state along that path. - auto end = merge_order.end(); - for (auto merge_candidate_it = std::next(it); merge_candidate_it != end; - ++merge_candidate_it) { - auto merge_candidate_it_id = - cond_state_map_.LookupId(*merge_candidate_it); - if (merge_candidate_it_id != id) continue; - TF_RETURN_IF_ERROR(cond.AddMerge(*merge_candidate_it)); + // Cluster merge nodes by CondId and AncestorId in order of nesting. 
+ using ClusterPair = std::pair; + std::deque> merge_clusters; + std::map merge_cluster_index; + for (Node* merge : merge_order) { + auto cond_id = state_map_.LookupCondId(merge); + if (state_map_.IsDead(cond_id)) continue; + + ClusterPair key = + std::make_pair(cond_id, state_map_.LookupAncestorId(merge)); + auto idx = merge_cluster_index.find(key); + if (idx == merge_cluster_index.end()) { + merge_cluster_index[key] = merge_clusters.size(); + merge_clusters.push_back({merge}); + } else { + merge_clusters[idx->second].emplace_back(merge); } + } + // Extract the conditionals from inner most to outer most. Extracting from + // innermost to outermost enables the extraction pass to stop once it + // encounters a Switch node instead of having to keep track of Switch/Merge + // nodes seen. + for (const auto& cluster : merge_clusters) { + // Construct a Conditional with the predicate of the merge. + Conditional cond(merge_to_predicate_.at(cluster.front()), this, + &state_map_); + for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge)); TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_)); if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); } - // All remaining Switch nodes are not reachable from a Merge node and - // removed. This is to account for dead Switch nodes. 
- for (int s_id : switch_ids) delete_nodes_.push_back(s_id); - for (Node* m : merge_order) delete_nodes_.push_back(m->id()); - DeleteReachableNodes(); + DeleteReachableAndDeadNodes(switch_ids, merge_order); return Status::OK(); } @@ -1359,11 +1355,14 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) { for (Node* n : graph_->nodes()) { n->ClearAttr(kCondGroupDebugAttr); - n->AddAttr(kCondGroupDebugAttr, cond_state_map_.CondStateToString(n)); + n->AddAttr(kCondGroupDebugAttr, + absl::StrCat(state_map_.CondStateToString(n), "_", + state_map_.AncestorStateToString(n))); } LOG(INFO) << "FunctionalizeControlFlow (" << name << "): " << dump_graph::DumpGraphToFile( - strings::StrCat("functionalize_", name), *graph_, library_); + absl::StrCat("functionalize_cond_", name), *graph_, + library_); } Status FunctionalizeCond::Functionalize(Graph* graph, diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 86436011c6ebdc608a5811a1b0d6a10015d405bd..189980894073b1da1a12d1c284536336eb920900 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -43,59 +43,53 @@ enum class BranchType { kNeither = 3, }; -// CondStateMap is responsible for mapping from each graph Node to a CondState, -// where each CondState is the array of CondNodes (corresponding to switch, -// merge or dead states) as described below. For efficiency, this class interns -// the CondState, so that CondState equality comparisons are simply pointer +// StateMap is responsible for mapping from each graph Node to +// * a CondState, where each CondState is a map from predicate to branch (i,e., +// what predicates have to hold or not hold). 
+// * a AncestorState, where each AncestorState is a set of switch/merge nodes +// that are an ancestor of the node in the graph; +// For efficiency, this class interns the CondState (AncestorState), so that +// CondState (AncestorState) equality comparisons are simply pointer // comparisons. -class CondStateMap { +class StateMap { public: - explicit CondStateMap(Graph* graph); - - // Represents an entry in the CondState. An entry can either be the - // switch (along with predicate), merge, or dead: - // * switch node indicates a node that is executed along a branch with the - // given predicate - a branch can be then, else or both; - // * merge node indicates that the node is executed as output of a merge; - // * dead indicates that this node can never be executed; - struct CondNode { - enum class Type { kSwitch = 1, kMerge = 2, kDead = 3 }; - - CondNode(Type type, Node* switch_node = nullptr, - BranchType branch = BranchType::kNeither); - - string ToString() const; - bool operator==(const CondNode& other) const; - bool operator!=(const CondNode& other) const; - - // Type of node. - Type type; - - // Predicate and branch, only used when type is kSwitch. - OutputTensor predicate; - BranchType branch; + explicit StateMap(Graph* graph); + + // Compare two OutputTensors by (node id, index). + struct OutputTensorLess { + bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const; }; - // A node in the graph is executed when multiple conditions hold. The order - // represents the nesting of the predicates that hold and is used when - // extracting the nested conditionals. - using CondState = std::vector; + // A node in the graph is executed when multiple conditions hold. Keep track + // of the predicates that must hold for a node to execute. + using CondState = std::map; // Every unique ID is mapped to a CondState. using CondId = const CondState*; + // Keep track of which switch/merge node's feed into a node's values. 
+ using AncestorState = std::set; + + // Every unique ID is mapped to a AncestorState. + using AncestorId = const AncestorState*; + // Returns the CondId for a given node. - CondId LookupId(const Node* node) const; + CondId LookupCondId(const Node* node) const; // Returns the unique CondId for CondState. - CondId GetUniqueId(const CondState& state); - - // Returns the CondState for a Node. - // REQUIRES: node has a non-empty CondState. - const CondState& LookupState(const Node* node) const; + CondId GetCondId(const CondState& state); // Resets the CondId for a given node. - void ResetId(const Node* node, CondId id); + void ResetCondId(const Node* node, CondId id); + + // Returns the AncestorId for a given node. + AncestorId LookupAncestorId(const Node* node) const; + + // Returns the unique AncestorId for CondState. + AncestorId GetAncestorId(const AncestorState& state); + + // Resets the AncestorId for a given node. + void ResetAncestorId(const Node* node, AncestorId id); // Marks `node` as dead. void MarkDead(const Node* node); @@ -103,45 +97,30 @@ class CondStateMap { // Determine branch execution of CondState. BranchType FindBranchOf(CondId id, OutputTensor predicate) const; - // Enum to represent whether one cond flow state contains another. - enum ContainsResult { - kIncomparable, - kEqual, - kLhsContainsRhs, - kRhsContainsLhs - }; - - // Returns whether the lhs CondState holds wherever rhs CondState hols. I.e., - // [(p,t)] contains [(p,t), (r,t)]. - ContainsResult LhsHoldsWhereverRhsHolds(CondId lhs, CondId rhs); - // Returns textual representation of node's CondState. string CondStateToString(const Node* node) const; string CondStateToString(CondId id) const; + // Returns textual representation of node's AncestorState. + string AncestorStateToString(const Node* node) const; + // Returns whether the cond state is the dead state. bool IsDead(CondId id) const; // Returns whether the cond state is the empty state. 
bool IsEmpty(CondId id) const; - // Computes the predicates that have to hold for a node to execute and returns - // whether it was possible to determine the predicates that must hold. `scope` - // is populated with these predicates. Scope differs from state in that it - // does not include merge and both nodes. - bool ScopeIn(CondId id, CondId* scope); - private: - // Hash for CondNode and CondState. - struct CondHash { - size_t operator()(const CondNode& item) const; - size_t operator()(const CondState& vec) const; + // Hash for CondState and AncestorState. + struct Hash { + size_t operator()(const CondState& map) const; + size_t operator()(const AncestorState& map) const; }; // Set to keep track of unique CondStates. // Pointers to the entries in the unordered set are used as identifiers: // unordered_set guarantees that the pointers remain the same. - std::unordered_set condstate_set_; + std::unordered_set condstate_set_; // Mapping from Node id to CondId. std::vector node_to_condid_map_; @@ -150,7 +129,12 @@ class CondStateMap { // from Node id in the original graph to the CondId, but there will be nodes // added to the original graph (such as If nodes) whose CondState needs to be // tracked too. - std::unordered_map added_node_mapping_; + std::unordered_map added_node_condid_mapping_; + + // AncestorId variants of the CondId members. + std::unordered_set ancestorstate_set_; + std::vector node_to_ancestorid_map_; + std::unordered_map added_node_ancestorid_mapping_; // Identifier of the dead flow state. The empty flow state is represented with // a nullptr. @@ -173,7 +157,8 @@ class FunctionalizeCond { // Add a If node to the graph defined by def that will, amongst other, replace // replacee in the graph. - xla::StatusOr AddIfNode(const NodeDef& def, const Node* replacee); + xla::StatusOr AddIfNode(const NodeDef& def, const Node* replacee, + const OutputTensor& predicate); // Propagates the state of a newly inserted node. 
Status PropagateUpdatedState(const Node* replacee); @@ -185,35 +170,42 @@ class FunctionalizeCond { FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); // Performs the actual cond functionalization. Iterate over groups of merge - // nodes (linked by common predicate & CondIds of the incomming edges), - // from innermost to outermost, and extract into If nodes. + // nodes (linked by common predicates & ancestor IDs), from innermost to + // outermost, and extract into If nodes. Status FunctionalizeInternal(); // Returns the forward flow state propagated along edge `e`. - // This may modify cond_state_map_. - CondStateMap::CondId StateAlongEdge(const Edge* e); + // This may modify state_map_. + StateMap::CondId StateAlongEdge(const Edge* e); - // Determines the CondState of all the nodes in the given vector where - // the input is expected in reverse topological order. - // This populates the cond_state_map_. - Status DetermineCondStates(std::vector rev_topo_order); + // Determines the CondState and AncestorState of all the nodes in the given + // vector where the input is expected in reverse topological order. + // This populates the state_map_. + Status DetermineStates(std::vector rev_topo_order); // Determine the CondState for a given node using the incomming edges // to the node. Note: it is expected that this node's CondState is only // determined once its input's CondState is. - Status DetermineCondState(Node* dst); + Status DetermineCondState(Node* dst) { + if (IsMerge(dst)) return DetermineCondStateMerge(dst); + return DetermineCondStateNonMerge(dst); + } // Helper functions for DetermineCondState. + Status DetermineCondStateNonMerge(Node* dst); Status DetermineCondStateMerge(Node* dst); - // Helper functions for DetermineCondStates. Determines the dst node's - // CondState by joining the src and dst's CondState where either - // the dst node is a merge or not. - // These may modify cond_state_map_. 
- xla::StatusOr JoinCondStatesMerge( - CondStateMap::CondId src, CondStateMap::CondId dst); - xla::StatusOr JoinCondStatesNonMerge( - CondStateMap::CondId src, CondStateMap::CondId dst); + // Determines the dst node's CondState by joining the src and dst's CondState + // where either the dst node is a merge or not. + // These may modify state_map_. + xla::StatusOr JoinCondStatesMerge(Node* merge, + StateMap::CondId src, + StateMap::CondId dst); + xla::StatusOr JoinCondStatesNonMerge(StateMap::CondId src, + StateMap::CondId dst); + + // Determines which switch/merge nodes are ancestors of this node. + Status DetermineAncestorState(Node* dst); // Checks if a merge node is redundant and if so removes it from the graph. Status RemoveRedundantMerge(Node* node); @@ -225,15 +217,18 @@ class FunctionalizeCond { // nesting depth. void SortMergeNodes(std::vector* merge_order); - // Deletes all nodes in/consumers of `delete_nodes_`. - void DeleteReachableNodes(); + // Deletes all nodes in/consumers reachable from switch/merge nodes that were + // extracted. + void DeleteReachableAndDeadNodes(const std::vector& switch_ids, + const std::vector& merge_order); - // Member used to unique the CondState to a unique CondId and keep track of - // CondState/CondId per Node. - CondStateMap cond_state_map_; + // Member used to unique the CondState to a unique CondId (AncestorState to a + // unique AncestorId) and keep track of CondState/CondId + // (AncestorState/AncestorId) per Node. + StateMap state_map_; - // Nodes to be deleted. - std::deque delete_nodes_; + // Mapping from merge nodes to predicate. 
+ std::unordered_map merge_to_predicate_; FunctionLibraryDefinition* library_; Graph* graph_; diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index a27f8893925855f536801a8a68855b82ac07462d..b0aabd63bbda784b3b7103a438ce025eea0cd93b 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -37,28 +37,23 @@ class FunctionalizeCondTest : public ::testing::Test { flib_def_.get())); } - CondStateMap::CondId GetUniqueId( - const CondStateMap::CondStateMap::CondState& state) { - return fc_->cond_state_map_.GetUniqueId(state); + StateMap::CondId GetUniqueId(const StateMap::StateMap::CondState& state) { + return fc_->state_map_.GetCondId(state); } - xla::StatusOr JoinCondStatesNonMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { - return fc_->JoinCondStatesNonMerge(src, dst); - } - - xla::StatusOr JoinCondStatesMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { - return fc_->JoinCondStatesMerge(src, dst); + string GetString(const StateMap::StateMap::CondId id) { + return fc_->state_map_.CondStateToString(id); } - bool ScopeIn(CondStateMap::CondId ff, CondStateMap::CondId* scope) { - return fc_->cond_state_map_.ScopeIn(ff, scope); + xla::StatusOr JoinCondStatesNonMerge(StateMap::CondId src, + StateMap::CondId dst) { + return fc_->JoinCondStatesNonMerge(src, dst); } - CondStateMap::ContainsResult LhsHoldsWhereverRhsHolds( - CondStateMap::CondId lhs, CondStateMap::CondId rhs) { - return fc_->cond_state_map_.LhsHoldsWhereverRhsHolds(lhs, rhs); + xla::StatusOr JoinCondStatesMerge(Node* n, + StateMap::CondId src, + StateMap::CondId dst) { + return fc_->JoinCondStatesMerge(n, src, dst); } FunctionDefLibrary fdef_lib_; @@ -69,50 +64,6 @@ class FunctionalizeCondTest : public ::testing::Test { namespace { -TEST_F(FunctionalizeCondTest, ScopeIn) { - Tensor pred_tensor(DT_BOOL, TensorShape()); - 
pred_tensor.flat().setZero(); - Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred"); - Tensor val_tensor(DT_INT32, TensorShape()); - val_tensor.flat().setZero(); - Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); - Node* s = test::graph::Switch(graph_.get(), val, pred); - - { - CondStateMap::CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch)); - CondStateMap::CondId id = GetUniqueId(ss); - CondStateMap::CondId scope; - ASSERT_TRUE(ScopeIn(id, &scope)); - ASSERT_TRUE(id == scope); - } - - CondStateMap::CondState empty; - { - CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth)); - ss.emplace_back( - CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge)); - CondStateMap::CondId id = GetUniqueId(ss); - CondStateMap::CondId scope_1; - ASSERT_TRUE(ScopeIn(id, &scope_1)); - ASSERT_TRUE(scope_1 == GetUniqueId(empty)); - ASSERT_TRUE(id != scope_1); - - ss.clear(); - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth)); - id = GetUniqueId(ss); - CondStateMap::CondId scope_2; - ASSERT_TRUE(ScopeIn(id, &scope_2)); - - ASSERT_TRUE(LhsHoldsWhereverRhsHolds(scope_1, scope_2) == - CondStateMap::ContainsResult::kLhsContainsRhs); - } -} - TEST_F(FunctionalizeCondTest, JoinCondStates) { Tensor pred_tensor(DT_BOOL, TensorShape()); pred_tensor.flat().setZero(); @@ -120,22 +71,18 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { Tensor val_tensor(DT_INT32, TensorShape()); val_tensor.flat().setZero(); Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); - Node* s = test::graph::Switch(graph_.get(), val, pred); + Node* m = test::graph::Merge(graph_.get(), val, val); - CondStateMap::CondId empty = GetUniqueId({}); - - CondStateMap::CondId then_branch; + StateMap::CondId then_branch; { - CondStateMap::CondState ss; 
- ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch)); + StateMap::CondState ss; + ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kThenBranch)); then_branch = GetUniqueId(ss); } - CondStateMap::CondId else_branch; + StateMap::CondId else_branch; { - CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kElseBranch)); + StateMap::CondState ss; + ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kElseBranch)); else_branch = GetUniqueId(ss); } @@ -144,39 +91,14 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { EXPECT_TRUE(errors::IsInvalidArgument(status)); // Merge between then and else branch. - auto joined_or = JoinCondStatesMerge(then_branch, else_branch); + auto joined_or = JoinCondStatesMerge(m, then_branch, else_branch); TF_EXPECT_OK(joined_or.status()); - CondStateMap::CondId joined = joined_or.ValueOrDie(); + StateMap::CondId joined = joined_or.ValueOrDie(); // Merge between then branch and both branch. auto t = JoinCondStatesNonMerge(then_branch, joined); // Note: this is OK in terms of constraint predication, but TF_EXPECT_OK(t.status()); - - // Post merge the propagated forward flow state has an additional merge. - CondStateMap::CondId post_merge; - { - CondStateMap::CondState ss; - ss = *joined; - ss.emplace_back( - CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge)); - post_merge = GetUniqueId(ss); - } - - t = JoinCondStatesNonMerge(post_merge, joined); - TF_EXPECT_OK(t.status()); - EXPECT_TRUE(joined == t.ValueOrDie()); - - // No predicate that results in two paths predicated on different conditions - // merge. - t = JoinCondStatesMerge(post_merge, joined); - EXPECT_FALSE(t.ok()); - - // Post the merge we are effectively in the root scope and merging should - // result in the more restrictive post merge state. 
- t = JoinCondStatesNonMerge(post_merge, empty); - TF_EXPECT_OK(t.status()); - EXPECT_TRUE(post_merge == t.ValueOrDie()); } } // namespace diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 5932be4e525dec11a8f3c59bb85e0449e76e79c0..f792c520329039c8da63d07ea27fa1c403f5c67d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -31,11 +31,16 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -68,4 +73,146 @@ Status FunctionalizeControlFlow(Graph* graph, return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); } +Status FunctionalizeControlFlowForFunction( + const string& func_name, const string& new_func_name, + const protobuf::Map& attrs, + FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, + std::map* canonicalized_name_to_new_name) { + // Convert the function to Graph. 
+ FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* body = flr->GetFunctionBody(handle); + const FunctionDef& fdef = body->fdef; + + // If any node has associated functions, functionalize them first. + // Gather nodes with associated functions first, because rewriting those nodes + // might involve node deletion/addition. Avoid modifying nodes while iterating + // it. + std::vector>> + nodes_to_associated_functions; + for (auto* n : body->graph->nodes()) { + auto associated_functions = GetAssociatedFunctions(*n, flr); + if (!associated_functions.empty()) { + nodes_to_associated_functions.push_back({n, associated_functions}); + } + } + for (auto iter : nodes_to_associated_functions) { + Node* n = iter.first; + auto associated_functions = iter.second; + for (auto& associated_function : associated_functions) { + string name = associated_function.func_name(); + string canonicalized_name = Canonicalize(name, AttrSlice(&attrs)); + auto iter = canonicalized_name_to_new_name->find(canonicalized_name); + string new_name; + if (iter != canonicalized_name_to_new_name->end()) { + // If we already functionalized this function, skip functionalization + // but still rewrite the node. + new_name = iter->second; + } else { + new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + name, new_name, attrs, fld, flr, canonicalized_name_to_new_name)); + (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + } + // Notice that if "n" is a function call, RewriteAssociatedFunction() will + // delete it and create a new node instead, making "n" an invalid pointer. 
+ // That's fine because in that case, associated_functions will only have + // one member and the loop will only run once. + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + body->graph, n, fld, associated_function, new_name)); + } + } + + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *body->graph, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), + *body->graph, fld); + } + FunctionDef functionalized_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef)); + + // Copy signature and ret from original FunctionDef. + *functionalized_fdef.mutable_signature() = fdef.signature(); + *functionalized_fdef.mutable_ret() = fdef.ret(); + functionalized_fdef.mutable_signature()->set_name(new_func_name); + + // Add rewritten FunctionDef into library. + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } + + return ret_status; +} + +Status FunctionalizeControlFlowPass::Run( + const GraphOptimizationPassOptions& options) { + Graph* graph = options.graph->get(); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph, + options.flib_def); + } + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime( + /*device_mgr=*/nullptr, options.session_options->env, + TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions())); + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + // Find XLA compile ops and its corresponding FunctionDef. 
+ static std::map* kNodeTypeToFunctionAttrMapping = + new std::map{ + {"TPUCompile", "function"}, + {"XlaLaunch", "function"}, + }; + std::map canonicalized_name_to_new_name; + for (Node* n : graph->nodes()) { + auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); + if (it == kNodeTypeToFunctionAttrMapping->end()) { + continue; + } + const string func_attr = it->second; + if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) != + kNodeTypeToFunctionAttrMapping->end()) { + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func)); + VLOG(2) << "Graph has node " << n->type_string() + << ". Corresponding function: " << func.name(); + string new_func_name = options.flib_def->UniqueFunctionName( + absl::StrCat(func.name(), "_f15n_")); + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func.name(), new_func_name, func.attr(), options.flib_def, flr, + &canonicalized_name_to_new_name)); + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); + } + } + + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph, + options.flib_def); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 55600f2a8b5302cef26b9be4ccd0f8804476a17a..ba99205640ccdc83a3a4d50e3ec474907894a835 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" @@ -32,6 +33,14 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, Graph* graph, FunctionLibraryDefinition* library); +// This pass looks at the graph and all associated FunctionDefs, and turns +// traditional control flow structure (Switch/Merge/etc.) into functional +// control flow structure (If/While). +class FunctionalizeControlFlowPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc new file mode 100644 index 0000000000000000000000000000000000000000..a10a9d0499457bbc0383ea3a8c678f153e21894b --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" + +namespace tensorflow { + +// This pass is required for some AOT backends and all JIT backends, so this +// file exists as a separate lib and will be linked to both AOT and JIT. +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27, + FunctionalizeControlFlowPass); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index c068a4110c0bb14282379eb7a3cbdae4e80ddbd6..c3841f996f801e855da75b23f01d41674ec51c4d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" @@ -112,16 +113,12 @@ TEST(FunctionalizeControlFlow, Conditional) { auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::XlaIf(scope.WithOpName(op_name), less, - std::initializer_list{less, y, x}, then_fn, - else_fn, {DT_INT32}); + auto if_op = ops::If(scope.WithOpName(op_name), less, + std::initializer_list{less, y, x}, {DT_INT32}, + then_fn, else_fn); auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); - // TODO(jpienaar): Create wrapper for IfOp. 
- for (NodeDef& n : *expected.mutable_node()) { - if (n.op() == "XlaIf") n.set_op("If"); - } TF_EXPECT_GRAPH_EQ(expected, graph_def); } @@ -177,7 +174,7 @@ TEST(FunctionalizeControlFlow, Conditional) { Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, NameAttrList* body) { for (const NodeDef& node : graph.node()) { - if (node.op() == "XlaWhile") { + if (node.op() == "While") { const NameAttrList* result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result)); *cond = *result; @@ -186,7 +183,7 @@ Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, return Status::OK(); } } - return errors::NotFound("No XlaWhile node found in graph"); + return errors::NotFound("No While node found in graph"); } // Graph: @@ -255,8 +252,8 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); @@ -392,8 +389,8 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); GraphDef expected; TF_ASSERT_OK(scope.ToGraphDef(&expected)); TF_EXPECT_GRAPH_EQ(expected, graph_def); @@ -483,8 +480,8 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); 
auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); TF_EXPECT_GRAPH_EQ(expected, graph_def); @@ -625,8 +622,8 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list{x, y}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list{x, y}, cond_fn, body_fn); auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); GraphDef expected; @@ -864,9 +861,9 @@ TEST(FunctionalizeControlFlow, Complex) { auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); - auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"), - std::initializer_list{zero, y, x, var}, - outer_cond_fn, outer_body_fn); + auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), + std::initializer_list{zero, y, x, var}, + outer_cond_fn, outer_body_fn); auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); @@ -921,9 +918,9 @@ TEST(FunctionalizeControlFlow, Complex) { auto one_j = ops::Const( scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); auto while_op = - ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"), - std::initializer_list{one_j, arg1, arg2, arg3}, - inner_cond_fn, inner_body_fn); + ops::While(scope.WithOpName("outer/LoopCond_1"), + std::initializer_list{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); auto one_outer = ops::Const( scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); diff --git 
a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc index 924fcdd9cd72a6472e0b2748680f2552fa65ec79..54cebc61778ba051b9c903f8e2c3696cec69843a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc @@ -42,7 +42,7 @@ xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { const char* const kRetValOp = "_Retval"; NodeDef ret_def; ret_def.set_op(kRetValOp); - ret_def.set_name(strings::StrCat(kRetValOp, index)); + ret_def.set_name(absl::StrCat(kRetValOp, index)); AddNodeAttr("T", type, &ret_def); AddNodeAttr("index", index, &ret_def); return AddNodeDefToGraph(ret_def, graph); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h index 61940e3586c59ffc660eaac8f8d035fbbbdfeffd..582b49d5116acc651fb6242b5c2b9aeeac269532 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -43,13 +43,12 @@ xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index); // Returns a textual representation of the names of the nodes in the input. 
template string NodesToString(const T& nodes) { - return strings::StrCat("{", - absl::StrJoin(nodes, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, - node->name()); - }), - "}"); + return absl::StrCat("{", + absl::StrJoin(nodes, ",", + [](string* output, const Node* node) { + absl::StrAppend(output, node->name()); + }), + "}"); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 6e3c4b0e0f695f0073f2c8aa1a4b342e39ea4be5..7c3ad448ef546dd1ab2640a57d7d1d73ca3768ad 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace { @@ -132,7 +134,7 @@ Status CopySubgraph(const Graph& graph, const Frame* frame, StatusOr BuildArgNode(Graph* graph, DataType type, int index) { const char* const kArgOp = "_Arg"; NodeDef arg_def; - NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); + NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp); builder.Attr("T", type); builder.Attr("index", index); TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); @@ -473,12 +475,19 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, } } - // Builds the condition and body functions. 
+ // Builds the condition and body functions. Notice that we call + // FunctionalizeCond() on cond_graph and body_graph because we might have + // unfunctionalized "if" in cond_graph and body_graph. Functionalize them + // before they are encapsulated in FunctionDef. std::unique_ptr cond_graph; TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph)); + FixupSourceAndSinkEdges(cond_graph.get()); + TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library)); DataTypeVector arg_types; std::unique_ptr body_graph; TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); + FixupSourceAndSinkEdges(body_graph.get()); + TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library)); VLOG(2) << "Frame " << frame->name << " condition: " << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) @@ -487,9 +496,9 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, static std::atomic sequence_num(0LL); int64 id = ++sequence_num; NameAttrList cond_name; - cond_name.set_name(strings::StrCat("_functionalize_cond_", id)); + cond_name.set_name(absl::StrCat("_functionalize_cond_", id)); NameAttrList body_name; - body_name.set_name(strings::StrCat("_functionalize_body_", id)); + body_name.set_name(absl::StrCat("_functionalize_body_", id)); FunctionDef cond_fdef; TF_RETURN_IF_ERROR( GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef)); @@ -510,7 +519,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, // Builds a While operator. NodeDef while_def; - NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile"); + NodeDefBuilder builder(frame->loop_cond->name(), "While", library); builder.Attr("T", arg_types); builder.Attr("cond", cond_name); builder.Attr("body", body_name); @@ -653,9 +662,9 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, // There should be no cycle at this point, since while loops have been removed // from graph. 
- // Check that the newly added XlaWhile nodes don't feed into themselves. + // Check that the newly added While nodes don't feed into themselves. for (const Node* node : graph->op_nodes()) { - if (node->def().op() == "XlaWhile") { + if (node->def().op() == "While") { TF_RETURN_WITH_CONTEXT_IF_ERROR( CheckNodeNotInCycle(node, graph->num_node_ids()), "Functionalizing loop failed."); diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 1ed1fb3b021b27be00086b2e71cc9309e3d76049..c019a28e892ff89f559ddbec2360d6caa9c1808f 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -81,7 +80,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, TF_ASSIGN_OR_RETURN(auto literal, client->ComputeConstant(constant_graph)); TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, arg.type, &arg.constant_value)); + LiteralToHostTensor(literal, arg.type, &arg.constant_value)); } else { arg.kind = XlaCompiler::Argument::kParameter; } @@ -127,7 +126,7 @@ Status GraphCompiler::Compile() { TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch()) << "Not supported node: " << n->DebugString(); params.op_kernel = op_kernel.get(); - gtl::InlinedVector output_attr(n->num_outputs()); + absl::InlinedVector output_attr(n->num_outputs()); params.output_attr_array = output_attr.data(); // tensor_inputs_ is a buffer reused across graph traversal. 
We clean up and diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h index 127562eb23d775f17179cc9ee968ec2255cf3a14..ab7cac7100d39377828462f0dee5df98a7319cc3 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.h +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -89,7 +89,7 @@ class GraphCompiler { ScopedStepContainer* step_container_; // A buffer to hold tensor inputs to a node, this is reused across the graph // traversal. - gtl::InlinedVector tensor_inputs_; + absl::InlinedVector tensor_inputs_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 4c776fb1781e4d0b0d1fa5f313536eb42d6856bb..46794f7b5070a1a64ac8e16e6a066156a4fa693b 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -115,9 +115,6 @@ tf_kernel_library( deps = [ ":if_op", ":while_op", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", @@ -168,14 +165,11 @@ tf_kernel_library( "//tensorflow/core/kernels:sparse_to_dense_op", "//tensorflow/core/kernels:stack_ops", "//tensorflow/core/kernels:training_ops", - ] + if_mkl( - [ - "//tensorflow/core/kernels:mkl_transpose_op", - ], - [ - "//tensorflow/core/kernels:transpose_op", - ], - ), + "//tensorflow/core/kernels:transpose_op", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], ) tf_kernel_library( @@ -184,6 +178,7 @@ tf_kernel_library( hdrs = ["while_op.h"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", @@ -201,6 +196,7 @@ 
tf_kernel_library( hdrs = ["if_op.h"], deps = [ "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index edced6bc0e57cfc2b1c62f1e4a010dd316f7d092..a18e04995b5e1e0b0374f7b0edd6f5e114cf994a 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -26,7 +26,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, absl::Span block_shape, const xla::Literal& crops) { const int input_rank = input_tensor_shape.dims(); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); const int block_rank = block_shape.size(); diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index 2e383b1473590403823863f89264e5381d8e8806..182f7c99344845964f7010127718f876ab6e8a44 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -39,7 +39,7 @@ class BCastArgsOp : public XlaOpKernel { OP_REQUIRES( ctx, ctx->num_inputs() == 2, errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); - gtl::InlinedVector shapes; + absl::InlinedVector shapes; for (int i = 0; i < ctx->num_inputs(); ++i) { const TensorShape in_shape = ctx->InputShape(i); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), @@ -88,7 +88,7 @@ class BCastGradArgsOp : public XlaOpKernel { ctx, ctx->num_inputs() == 2, errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); - gtl::InlinedVector shapes; + absl::InlinedVector shapes; for (int i = 0; i < ctx->num_inputs(); ++i) { const TensorShape in_shape = ctx->InputShape(i); OP_REQUIRES(ctx, 
TensorShapeUtils::IsVector(in_shape), diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 12b0e38288e8f222ed506a75ec2575f27141c859..e96a1adce43c750314715107b4a1954d4a5b4e40 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -48,7 +48,7 @@ class DepthToSpaceOp : public XlaOpKernel { OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got: ", input_rank)); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::XlaOp input = ctx->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index a3389d5b905bf3ee15744ab4fcee193d312e2ae0..4af1e8b44cbbd02d8e3ea5e42d841c92288b5d56 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -34,15 +34,12 @@ class DynamicUpdateSliceOp : public XlaOpKernel { : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* ctx) override { - VLOG(3) << "DynamicUpdateSliceOp::Compile"; + DataType index_type = ctx->InputType("indices"); + CHECK(index_type == DT_INT32 || index_type == DT_INT64); - DataType index_type = input_type(2); - OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64, - errors::InvalidArgument("index must be int32 or int64")); - - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape update_shape = ctx->InputShape(1); - const TensorShape index_shape = ctx->InputShape(2); + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape update_shape = ctx->InputShape("update"); + const TensorShape index_shape = ctx->InputShape("indices"); OP_REQUIRES( ctx, @@ -57,13 +54,56 @@ class DynamicUpdateSliceOp : public XlaOpKernel { 
input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); - xla::XlaOp result = - xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2)); + xla::XlaOp result = xla::DynamicUpdateSlice( + ctx->Input("input"), ctx->Input("update"), ctx->Input("indices")); ctx->SetOutput(0, result); } }; REGISTER_XLA_OP(Name("XlaDynamicUpdateSlice"), DynamicUpdateSliceOp); +class DynamicSliceOp : public XlaOpKernel { + public: + explicit DynamicSliceOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* ctx) override { + DataType index_type = ctx->InputType("start_indices"); + CHECK(index_type == DT_INT32 || index_type == DT_INT64); + CHECK(index_type == ctx->InputType("size_indices")); + + const TensorShape input_shape = ctx->InputShape("input"); + const TensorShape start_indices_shape = ctx->InputShape("start_indices"); + const TensorShape size_indices_shape = ctx->InputShape("size_indices"); + + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(start_indices_shape) && + start_indices_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "start_indices must be a vector with length equal to " + "input rank, but input rank is ", + input_shape.dims(), " and start_indices has shape ", + start_indices_shape.DebugString())); + OP_REQUIRES(ctx, + TensorShapeUtils::IsVector(size_indices_shape) && + size_indices_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "size_indices must be a vector with length equal to " + "input rank, but input rank is ", + input_shape.dims(), " and size_indices has shape ", + size_indices_shape.DebugString())); + + std::vector size_indices; + OP_REQUIRES_OK( + ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices)); + xla::XlaOp result = xla::DynamicSlice( + ctx->Input("input"), ctx->Input("start_indices"), size_indices); + ctx->SetOutput(0, result); + } +}; + 
+REGISTER_XLA_OP(Name("XlaDynamicSlice").CompileTimeConstInput("size_indices"), + DynamicSliceOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 6e1dbf5472f0b1eb0abcbe29c553ae926ecf2d8a..56da50f140893c68c8a1556853884720b21c7229 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/if_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } // TODO(b/35949885): There is duplication here with the handling of the @@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { options.resolve_compile_time_constants = false; options.return_updated_values_for_all_resources = true; options.is_entry_computation = false; + options.add_token_input_output = has_token_input_output_; XlaCompiler* compiler = ctx->compiler(); XlaCompiler::CompilationResult then_result; @@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { std::vector inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = then_result.input_mapping[i] + 1; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token 
input for this "if" op. + std::vector token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(b, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); @@ -219,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } ctx->SetOutput(i, output_handle); } + if (has_token_input_output_) { + // Set token output for this "if" op. + xla::XlaOp token_output = + xla::GetTupleElement(outputs, output_types_.size()); + auto shape_or = b->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } // Updates the values of any resource variables modified by the conditional // bodies. 
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h index f9bc98a198a72dcc0594e61971713bf890ce30b6..7783e13a8a5dacc1901392703687230020f82483 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.h +++ b/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -52,6 +52,8 @@ class XlaIfOp : public XlaOpKernel { DataType cond_type_; DataTypeVector input_types_; DataTypeVector output_types_; + bool has_token_input_output_; + std::vector token_input_nodes_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 22a45b2a11e8ecb688f8e773ef4b286eafe68f4f..3d81ae9eb89a80e5b89b180ad77521c5ed15e79d 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -78,14 +78,14 @@ class ArgMaxCustomCallOp : public XlaOpKernel { std::vector args; args.push_back(ctx->Input(0)); args.push_back(xla::ConstantLiteral( - &b, *xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); + &b, xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. 
args.push_back(xla::ConstantLiteral( - &b, *xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); + &b, xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); args.push_back( - xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0(dim))); + xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0(dim))); } xla::Shape xla_shape = diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index f6f158a73be42ea2602811ad64a2a2c655dab088..27690c156e4da129ad139f3880bba3a208b5606d 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -138,7 +138,7 @@ xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format, int num_dims = num_spatial_dims + 2; int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format); int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format); - gtl::InlinedVector spatial_dimensions(num_spatial_dims); + absl::InlinedVector spatial_dimensions(num_spatial_dims); for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) { spatial_dimensions[spatial_dim] = GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 598248563bb93146e6dea3016822d26b8bf368e7..118f2798d559f43acb7f6394a7337426164325ef 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -69,7 +69,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "data shape: " << data_shape.DebugString(); VLOG(1) << "axes : " << absl::StrJoin(axes, ","); - gtl::InlinedVector bitmap(data_shape.dims(), false); + absl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { @@ -103,7 +103,7 @@ 
void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::XlaBuilder* const b = ctx->builder(); // Construct the builder for the reduction lambda. - xla::XlaBuilder r(strings::StrCat(desc, "-reduction")); + xla::XlaBuilder r(absl::StrCat(desc, "-reduction")); xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index c0afccaa5b15dd33fcd016dfdd9bb18e244bf90a..8494864b33a44b03a07e3fea7766285f54074e7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -97,7 +97,7 @@ class ReverseV2Op : public XlaOpKernel { // witnessed_axes is used to ensure that the same axis is not marked to be // reversed multiple times. - gtl::InlinedVector witnessed_axes(x_shape.dims(), false); + absl::InlinedVector witnessed_axes(x_shape.dims(), false); for (int d = 0; d < axes.size(); ++d) { OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 4e0cf99d8e7ff45ed9145981b5e2e637ce4d4e4b..2e0a69b70ef91fb5fee8aac888fdc90517c1356e 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -115,7 +115,7 @@ class ExpandDimsOp : public XlaOpKernel { // accept legacy scalars, even when they should be forbidden by the graphdef // version. 
OP_REQUIRES(ctx, dim_shape.num_elements() == 1, - errors::InvalidArgument(strings::StrCat( + errors::InvalidArgument(absl::StrCat( "dim input to ExpandDims must be a scalar; got ", dim_shape.DebugString()))); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index b7b4f3a5465c8eea832ef940b7c84a7435edc38c..76b79be6f6f6b5ecbe9edcffb81f2834fdac9a56 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -26,7 +26,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, absl::Span block_shape, const xla::Literal& paddings) { const int input_rank = input_tensor_shape.dims(); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); const int block_rank = block_shape.size(); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 4493539fe34f0ce635fdc58660d4ff90af9c9379..3293c13b21bc4825c83f494b7f2d48a9b3000f9e 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -48,7 +48,7 @@ class SpaceToDepthOp : public XlaOpKernel { OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got ", input_rank)); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::XlaOp input = ctx->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index df91900570107609c0f1c2281faaab8a5e65b98b..ee70f508a9586d5f47bd7bb7670506d4c92b369f 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -111,7 +111,7 @@ class StackOp : public XlaOpKernel { xla::XlaOp value; XlaContext& xc = 
XlaContext::Get(ctx); XlaResource* resource; - string name = strings::StrCat("Stack: ", stack_name_); + string name = absl::StrCat("Stack: ", stack_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, TensorShape(), value, /*tensor_array_size=*/size, diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 472d4744d7d9cec65645c3259b0c097f0c756bac..2b2e3de64fd0db9d99efa46ecaf7a0fefbae6645 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -46,9 +46,9 @@ class StridedSliceOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); TensorShape final_shape; - gtl::InlinedVector begin; - gtl::InlinedVector end; - gtl::InlinedVector strides; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); @@ -72,8 +72,8 @@ class StridedSliceOp : public XlaOpKernel { shrink_axis_mask_, &dummy_processing_shape, &final_shape, &dummy, &dummy, &dummy, &begin, &end, &strides)); - gtl::InlinedVector dimensions_to_reverse; - gtl::InlinedVector slice_begin, slice_end, slice_strides; + absl::InlinedVector dimensions_to_reverse; + absl::InlinedVector slice_begin, slice_end, slice_strides; for (int i = 0; i < begin.size(); ++i) { if (strides[i] > 0) { @@ -127,9 +127,9 @@ class StridedSliceGradOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape processing_shape, final_shape; - gtl::InlinedVector begin; - gtl::InlinedVector end; - gtl::InlinedVector strides; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; TensorShape input_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); @@ -175,7 +175,7 @@ class StridedSliceGradOp : public XlaOpKernel { 
grad = xla::Reshape(grad, processing_shape.dim_sizes()); // Pad the input gradients. - gtl::InlinedVector dimensions_to_reverse; + absl::InlinedVector dimensions_to_reverse; xla::PaddingConfig padding_config; for (int i = 0; i < processing_shape.dims(); ++i) { @@ -238,9 +238,9 @@ class StridedSliceAssignOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape final_shape; - gtl::InlinedVector begin; - gtl::InlinedVector end; - gtl::InlinedVector strides; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); @@ -287,8 +287,8 @@ class StridedSliceAssignOp : public XlaOpKernel { xla::XlaOp rhs = ctx->Input(4); - gtl::InlinedVector dimensions_to_reverse; - gtl::InlinedVector slice_begin, slice_dims; + absl::InlinedVector dimensions_to_reverse; + absl::InlinedVector slice_begin, slice_dims; for (int i = 0; i < begin.size(); ++i) { // TODO(phawkins): implement strides != 1 OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index bb114d1aedd57c7de992a05b37ad53443489596f..94108b764fd32fc77520f9a8ea16065c27e6accf 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -167,7 +167,7 @@ class TensorArrayOp : public XlaOpKernel { XlaContext& xc = XlaContext::Get(ctx); XlaResource* var; - string name = strings::StrCat("TensorArray: ", tensor_array_name_); + string name = absl::StrCat("TensorArray: ", tensor_array_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), dtype_, shape, value, /*tensor_array_size=*/size, diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 
f9148b394212777271f9eba51313ee17b19819af..6b303b31d43ce2249a87f25723caf34f84c8387d 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -61,7 +61,7 @@ class TransposeOp : public XlaOpKernel { std::vector transposed_order; // Check whether permutation is a permutation of integers of [0 .. dims). - gtl::InlinedVector bits(dims); + absl::InlinedVector bits(dims); bool is_identity = true; for (int i = 0; i < dims; ++i) { const int32 d = perm[i]; diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 296518229ebf0ba46717afc4f26d5ae1551c2862..559414eeaa5fec75e5a9d1866baaf738c024cd15 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/while_op.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -90,6 +91,11 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { cond_name_attr_ = *name_attr; OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr)); body_name_attr_ = *name_attr; + if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { + has_token_input_output_ = false; + } else { + has_token_input_output_ = !token_input_nodes_.empty(); + } } void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { @@ -120,6 +126,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { body_options.return_updated_values_for_all_resources = true; body_options.resolve_compile_time_constants = false; body_options.is_entry_computation = false; + body_options.add_token_input_output = has_token_input_output_; XlaCompiler::CompilationResult body; OP_REQUIRES_OK(ctx, 
compiler->CompileFunction(body_options, body_name_attr_, arguments, &body)); @@ -192,6 +199,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { cond_options.use_tuple_arg = true; cond_options.resolve_compile_time_constants = false; cond_options.is_entry_computation = false; + cond_options.add_token_input_output = has_token_input_output_; XlaCompiler::CompilationResult cond; OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, arguments, &cond)); @@ -238,7 +246,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { std::vector inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = body.input_mapping[i]; - if (ctx->input_type(input_num) == DT_RESOURCE) { + if (has_token_input_output_ && i == num_inputs - 1) { + // Set token input for this "while" op. + std::vector token_inputs; + for (const string& node_name : token_input_nodes_) { + auto token_or = compiler->GetNodeToken(node_name); + OP_REQUIRES_OK(ctx, token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + inputs[i] = xla::AfterAll(builder, token_inputs); + } else if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder)); @@ -273,6 +290,18 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::GetTupleElement(while_result, i)); } } + if (has_token_input_output_) { + // Set token output for this "while" op. 
+ xla::XlaOp token_output = + xla::GetTupleElement(while_result, ctx->num_outputs()); + auto shape_or = builder->GetShape(token_output); + OP_REQUIRES_OK(ctx, shape_or.status()); + OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()), + errors::FailedPrecondition( + "Token output is not token type: ", + xla::ShapeUtil::HumanString(shape_or.ValueOrDie()))); + OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output)); + } // Updates the values of any resource variables modified by the loop. for (int i = 0; i < body.resource_updates.size(); ++i) { diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h index 67edebabf9f643a919d0f06c228e2d224a49a2af..aeeff40e68f8b778628b9e85bd9b4ddcb73883a5 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.h +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -56,6 +56,8 @@ class XlaWhileOp : public XlaOpKernel { private: NameAttrList cond_name_attr_; NameAttrList body_name_attr_; + bool has_token_input_output_; + std::vector token_input_nodes_; TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp); }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc index 8848623868091f8d19b1622f23ba23c68689d90d..fecc7c556eb4121b912796e5811632c46769b479 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -84,7 +84,7 @@ class XlaConvOp : public XlaOpKernel { private: xla::ConvolutionDimensionNumbers dnums_; - xla::PrecisionConfigProto precision_config_; + xla::PrecisionConfig precision_config_; TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp); }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc index 2fed53e5c072e1a50e0f07f45357ee86c90f986f..40b15b5579ab9862b9d30df74af9877c98c4aa2c 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ 
-54,7 +54,7 @@ class XlaDotOp : public XlaOpKernel { private: xla::DotDimensionNumbers dnums_; - xla::PrecisionConfigProto precision_config_; + xla::PrecisionConfig precision_config_; TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp); }; diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 9365d203f06d9f1cad320353f43db010d39697af..8597e7f139d8d32b7e08782e70a4ee44d02618f2 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -205,7 +205,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index d8c050d09e871c80e128989c9fbdb57c266b19ed..64f2d781a694393f6fabcd9f443cdb4911921c97 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -28,7 +28,7 @@ namespace tensorflow { xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, bool transpose_y, bool conjugate_x, bool conjugate_y, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); @@ -96,7 +96,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, y = xla::Conj(y); } - xla::PrecisionConfigProto precision_proto; + xla::PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 6cfccd55530ff40a309673d57d1fe61fc8264316..6edd63a4d3b66c21aa4cce8c9f36eef0dc363cd8 100644 --- 
a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -43,11 +43,11 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, - bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::DEFAULT); +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, + bool transpose_y = false, bool conjugate_x = false, + bool conjugate_y = false, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index c50a8de33e93a91b1a414146147de48df603eb85..ab3d0a566839343828d176d9a46672824e425613 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -50,7 +50,7 @@ namespace { // l[..., j, j] // return l xla::XlaOp CholeskyUnblocked(xla::XlaOp a, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -150,7 +150,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, } // namespace xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 60cd7ded53fe862f29ca2bb68b175fcd1c89b70c..9a561c34b92ee45059f2a05336e682838f8e36e2 100644 --- 
a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -30,9 +30,9 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. // TODO(znado): handle the complex Hermitian case -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::HIGHEST); +xla::XlaOp Cholesky( + xla::XlaOp a, int64 block_size = 256, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index 0a140fa93caec28ebbbd666fd4fa518222ea23a4..6b3f2b6e065b5c99e2d0248237369ecc30188aa5 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -150,7 +150,7 @@ struct QRBlockResult { xla::XlaOp vs; // Shape: [..., m, n] }; xla::StatusOr QRBlock( - xla::XlaOp a, xla::PrecisionConfigProto::Precision precision) { + xla::XlaOp a, xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -257,7 +257,7 @@ xla::StatusOr QRBlock( xla::StatusOr ComputeWYRepresentation( xla::PrimitiveType type, absl::Span batch_dims, xla::XlaOp vs, xla::XlaOp taus, int64 m, int64 n, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { std::vector batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; @@ -332,7 +332,7 @@ xla::StatusOr ComputeWYRepresentation( // rather than WY transformations. 
xla::StatusOr QRDecomposition( xla::XlaOp a, bool full_matrices, int64 block_size, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h index 8a389fb7b053257adcd2a338dca52445c78381d1..24b537ac8b63b93e734c3d0e335ea455f7d51a54 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -35,8 +35,7 @@ struct QRDecompositionResult { xla::StatusOr QRDecomposition( xla::XlaOp a, bool full_matrices, int64 block_size = 128, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::HIGHEST); + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 37b2240b45b4ae6a587c827cfdfa1096b4e1737e..6524c2a9b1ada632d80edd234272760c2b545cc4 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -110,9 +110,9 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { }); } -xla::XlaOp InvertDiagonalBlocks( - xla::XlaOp diag_blocks, bool lower, bool transpose_a, bool conjugate_a, - xla::PrecisionConfigProto::Precision precision) { +xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, + bool transpose_a, bool conjugate_a, + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = diag_blocks.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { // Input is a batch of square lower triangular square matrices. 
Its shape is @@ -216,7 +216,7 @@ xla::XlaOp InvertDiagonalBlocks( dnums.add_rhs_batch_dimensions(0); dnums.add_lhs_contracting_dimensions(2); dnums.add_rhs_contracting_dimensions(1); - xla::PrecisionConfigProto precision_proto; + xla::PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); @@ -245,7 +245,7 @@ xla::XlaOp InvertDiagonalBlocks( xla::XlaOp SolveWithInvertedDiagonalBlocks( xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, bool lower, bool transpose_a, bool conjugate_a, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, @@ -346,7 +346,7 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, int64 block_size, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index ac42a4835295b7cb52697710d738f4728d3983d1..2303234f361e54cd2a0ad495cb03b371bed76877 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -57,11 +57,10 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. 
-xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - int64 block_size = 128, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::HIGHEST); +xla::XlaOp TriangularSolve( + xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, int64 block_size = 128, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index c26784852472061ffead03cfe7431f8b8ba0e555..804671fbc75b0a5a6e04b204822b6f084013cd8b 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -64,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, xla::Literal literal; switch (type) { case xla::U8: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::U32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::U64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S8: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::S64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::F32: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::F64: - literal = std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::C64: - literal = 
std::move(*xla::LiteralUtil::CreateR0(value)); + literal = xla::LiteralUtil::CreateR0(value); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -96,12 +96,12 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::BF16: - literal = std::move( - *xla::LiteralUtil::CreateR0(static_cast(value))); + literal = + xla::LiteralUtil::CreateR0(static_cast(value)); break; case xla::F16: - literal = std::move(*xla::LiteralUtil::CreateR0( - static_cast(value))); + literal = + xla::LiteralUtil::CreateR0(static_cast(value)); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 5300e2c878bf725b65544701eb3fdc6032553491..594ab1dfd0700f47501712183f6efe62d17e15e7 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -24,7 +24,7 @@ namespace tensorflow { xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - absl::Span initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder) { int arity = initial_values.size(); std::vector var_shapes; @@ -47,7 +47,7 @@ xla::StatusOr> XlaWhileLoop( // Build the condition. std::unique_ptr cond_builder = - builder->CreateSubBuilder(strings::StrCat(name, "_condition")); + builder->CreateSubBuilder(absl::StrCat(name, "_condition")); { auto parameter = xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); @@ -61,7 +61,7 @@ xla::StatusOr> XlaWhileLoop( // Build the body. 
std::unique_ptr body_builder = - builder->CreateSubBuilder(strings::StrCat(name, "_body")); + builder->CreateSubBuilder(absl::StrCat(name, "_body")); { auto parameter = xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); @@ -84,7 +84,7 @@ xla::StatusOr> XlaWhileLoop( xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder) { auto while_cond_fn = [&](absl::Span values, diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 115ebf390df6c215680e5982a6ceba546f384af8..f2134bb4495a12b8342961d96f70e7737f816c7d 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -19,11 +19,11 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { @@ -50,7 +50,7 @@ typedef std::function>( xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - absl::Span initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. 
@@ -65,7 +65,7 @@ typedef std::function>( xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index 7dc16b5a46791b81eef2c572736e1a1c7969b203..15f4c38da29507da9e092c1d5725b5f95a81d1b9 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -22,51 +22,61 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { +namespace { TEST(LiteralUtil, LiteralToHostTensor) { // int64 literal can only be converted to an int64 host tensor. - { - std::vector int64_values = {1, 2, 3}; - std::unique_ptr int64_values_literal = - xla::LiteralUtil::CreateR1(absl::Span(int64_values)); - Tensor host_tensor; - EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", - LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) - .error_message()); - EXPECT_EQ( - "Cannot convert literal of type S64 to tensor of type qint32", - LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor) - .error_message()); - EXPECT_TRUE( - LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor) - .ok()); - test::ExpectTensorEqual(host_tensor, - test::AsTensor(int64_values)); - } + std::vector int64_values = {1, 2, 3}; + xla::Literal int64_values_literal = + xla::LiteralUtil::CreateR1(absl::Span(int64_values)); + Tensor host_tensor; + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", + LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor) + .error_message()); + EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32", + LiteralToHostTensor(int64_values_literal, 
DT_QINT32, &host_tensor) + .error_message()); + EXPECT_TRUE( + LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok()); + test::ExpectTensorEqual(host_tensor, + test::AsTensor(int64_values)); +} + +template +using LiteralUtilTest = ::testing::Test; +using Types = + ::testing::Types, std::pair, + std::pair, std::pair, + std::pair>; + +TYPED_TEST_CASE(LiteralUtilTest, Types); + +TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) { + using int_type = typename TypeParam::first_type; + using qint_type = typename TypeParam::second_type; - { - // Repeat tests with int32. - Tensor host_tensor; - std::vector int32_values = {10, 11}; - std::unique_ptr int32_values_literal = - xla::LiteralUtil::CreateR1(absl::Span(int32_values)); - EXPECT_TRUE( - LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) - .ok()); - test::ExpectTensorEqual(host_tensor, - test::AsTensor(int32_values)); + Tensor host_tensor; + std::vector int_values = {10, 11}; + xla::Literal int_values_literal = + xla::LiteralUtil::CreateR1(absl::Span(int_values)); + EXPECT_TRUE(LiteralToHostTensor(int_values_literal, + DataTypeToEnum::value, &host_tensor) + .ok()); + test::ExpectTensorEqual(host_tensor, + test::AsTensor(int_values)); - EXPECT_TRUE( - LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor) - .ok()); - std::vector qint32_values = {10, 11}; - test::ExpectTensorEqual(host_tensor, - test::AsTensor(qint32_values)); + EXPECT_TRUE(LiteralToHostTensor(int_values_literal, + DataTypeToEnum::value, + &host_tensor) + .ok()); + std::vector qint_values = {10, 11}; + test::ExpectTensorEqual(host_tensor, + test::AsTensor(qint_values)); - EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64", - LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor) - .error_message()); - } + EXPECT_EQ( + error::INVALID_ARGUMENT, + LiteralToHostTensor(int_values_literal, DT_INT64, &host_tensor).code()); } +} // namespace } // namespace tensorflow diff --git 
a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 2cd9ae799f06afdcbae5429ef8caffd3b4d29c29..02363500efe1a11348eaf7d8b99da76307acdd3c 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -83,7 +83,7 @@ lhs_dilation: dilation to apply between input elements rhs_dilation: dilation to apply between kernel elements feature_group_count: number of feature groups for grouped convolution. dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto. -precision_config: a serialized xla::PrecisionConfigProto proto. +precision_config: a serialized xla::PrecisionConfig proto. )doc"); REGISTER_OP("XlaDot") @@ -102,7 +102,36 @@ Wraps the XLA ConvGeneralDilated operator, documented at lhs: the LHS tensor rhs: the RHS tensor dimension_numbers: a serialized xla::DotDimensionNumbers proto. -precision_config: a serialized xla::PrecisionConfigProto proto. +precision_config: a serialized xla::PrecisionConfig proto. +)doc"); + +REGISTER_OP("XlaDynamicSlice") + .Input("input: T") + .Input("start_indices: Tindices") + .Input("size_indices: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA DynamicSlice operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice +. + +DynamicSlice extracts a sub-array from the input array at dynamic +start_indices. The size of the slice in each dimension is passed in +size_indices, which specify the end point of exclusive slice intervals in each +dimension -- [start, start + size). The shape of start_indices must be rank == +1, with dimension size equal to the rank of operand. + +input: A `Tensor` of type T. + +start_indices: Rank 1 tensor of N integers containing the starting indices of + the slice for each dimension. Value must be greater than or equal to zero. 
+ +size_indices: List of N integers containing the slice size for each + dimension. Each value must be strictly greater than zero, and start + size + must be less than or equal to the size of the dimension. )doc"); REGISTER_OP("XlaDynamicUpdateSlice") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 3626de375ea9ac12e40ea5b5b591bb6d5262adbc..27dd18a9bbd5aceece41aaf61eb185acb537b3b6 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -291,13 +291,7 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): name=name) -def dynamic_slice(x, starts, sizes, name=None): - # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not - # a compile-time constant. This doesn't exactly mimic the semantics of dynamic - # slice if the slice is out of bounds. - return array_ops.slice(x, starts, sizes, name=name) - - +dynamic_slice = gen_xla_ops.xla_dynamic_slice dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice # TODO(phawkins): generalize tf.pad to support interior padding, and then remove diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 32ba6df2e6daa2add468a1bc0559d42606d1a9a6..20f2ce2919701731ef6e90d368b67545af95e8f9 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { -/*static*/ StringPiece XlaResourceOpInfo::XlaResourceOpKindToString( +/*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString( XlaResourceOpKind op_kind) { switch (op_kind) { case XlaResourceOpKind::kRead: @@ -30,11 +30,11 @@ namespace tensorflow { } } -static gtl::FlatMap* CreateResourceOpInfoMap() { - gtl::FlatMap* result = - new gtl::FlatMap; +static gtl::FlatMap* +CreateResourceOpInfoMap() { + auto* result = new gtl::FlatMap; - auto add = [&](StringPiece op, XlaResourceOpKind op_kind, + auto add = [&](absl::string_view op, XlaResourceOpKind op_kind, XlaResourceKind resource_kind) { auto insert_result = result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)}); @@ -103,23 +103,23 @@ static gtl::FlatMap* CreateResourceOpInfoMap() { return result; } -static const gtl::FlatMap& +static const gtl::FlatMap& GetStaticResourceOpInfoMap() { - static gtl::FlatMap* op_info_map = + static gtl::FlatMap* op_info_map = CreateResourceOpInfoMap(); return *op_info_map; } -const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op) { - const gtl::FlatMap& op_infos = +const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) { + const gtl::FlatMap& op_infos = GetStaticResourceOpInfoMap(); auto it = op_infos.find(op); return it == op_infos.end() ? 
nullptr : &it->second; } namespace resource_op_table_internal { -std::vector GetKnownResourceOps() { - std::vector result; +std::vector GetKnownResourceOps() { + std::vector result; for (const auto& p : GetStaticResourceOpInfoMap()) { result.push_back(p.first); } diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.h b/tensorflow/compiler/tf2xla/resource_operation_table.h index 7f627a64c6e8298a427cd87d25d4ba24835bf542..61c7a56ff0d4adb75e93ced3155b37102763c652 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.h +++ b/tensorflow/compiler/tf2xla/resource_operation_table.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/stringpiece.h" +#include "absl/strings/string_view.h" #include "tensorflow/core/platform/logging.h" // Exposes information about the resource operations supported by tf2xla in a @@ -47,7 +47,7 @@ class XlaResourceOpInfo { XlaResourceOpKind kind() const { return op_kind_; } XlaResourceKind resource_kind() const { return resource_kind_; } - static StringPiece XlaResourceOpKindToString(XlaResourceOpKind op_kind); + static absl::string_view XlaResourceOpKindToString(XlaResourceOpKind op_kind); private: XlaResourceOpKind op_kind_; @@ -57,13 +57,13 @@ class XlaResourceOpInfo { // Returns a XlaResourceOpInfo describing `op` if it is a resource operation // supported by tf2xla, otherwise returns null (i.e. if this returns null then // `op` is either not a resource operation or is unsupported by XLA). -const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op); +const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op); namespace resource_op_table_internal { // NB! Implementation detail exposed for unit testing, do not use. // // Returns the set of resource operations known by this module. 
-std::vector GetKnownResourceOps(); +std::vector GetKnownResourceOps(); } // namespace resource_op_table_internal } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc index 0343f80de9fed114a0097b981233277c3e12b378..a85ef040a7b65c2f6e405c3444eaef3019137b4b 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -34,7 +34,7 @@ bool HasResourceInputOrOutput(const OpDef& op_def) { TEST(ResourceOperationTableTest, HaveAllResourceOps) { gtl::FlatMap known_resource_ops; - for (StringPiece known_resource_op : + for (absl::string_view known_resource_op : resource_op_table_internal::GetKnownResourceOps()) { ASSERT_TRUE( known_resource_ops.insert({string(known_resource_op), false}).second); diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 2d7eb8b915b8245ba6573c30b2eb15b12fc3a1b4..8aae498be1042b5a55e849a03d438cd54dafca83 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -17,7 +17,6 @@ limitations under the License. #include "absl/strings/match.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..6cd7b24592f30d7202b985f3dfd082ea2d85e344 --- /dev/null +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/side_effect_util.h" + +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes"; + +const char kXlaTokenArgNodeName[] = "_xla_token_arg_node"; + +std::set CalculateTokenInputsForOutputToken(const Graph& g) { + std::set results; + Node* first_side_effecting_node_on_path = nullptr; + ReverseDFS(g, + [&](Node* n) { + std::vector token_input_nodes; + if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, + &token_input_nodes) + .ok() || + token_input_nodes.empty()) { + return; + } + + if (first_side_effecting_node_on_path != nullptr) { + return; + } + + first_side_effecting_node_on_path = n; + results.insert(n->name()); + }, + [&](Node* n) { + if (first_side_effecting_node_on_path == n) { + first_side_effecting_node_on_path = nullptr; + } + }, + NodeComparatorName()); + return results; +} + +bool HasSideEffectingNodes(const Graph& g) { + for (Node* n : g.nodes()) { + std::vector token_input_nodes; + if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes) + .ok() && + !token_input_nodes.empty()) { + return true; + } + } + return false; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h new file mode 100644 index 0000000000000000000000000000000000000000..ad07624729f0b0d2443b2fc43d32dfa3377ce115 --- /dev/null +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ 
-0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ + +#include + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Side-effecting nodes will have this attribute set. Its value is the list of +// node names which this node has side-effect dependencies on. +// +// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute, +// because they always have side effects. +// If and While nodes may or may not have this attribute, depending on whether +// their bodies have side-effecting nodes. +extern const char kXlaTokenInputNodesAttrName[]; + +// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a +// node has a side-effect dependency on the current graph's token input. +extern const char kXlaTokenArgNodeName[]; + +// Calculates side-effect dependencies for the graph's token output. +// Returns a set of node names representing these dependencies. +std::set CalculateTokenInputsForOutputToken(const Graph& g); + +// Returns whether a graph contains side-effecting nodes. 
+bool HasSideEffectingNodes(const Graph& g); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index f34af2d67debe8bfa4abcad19e42c55ea40c4e82..b22d53805d83069052cc5e16020d6c540d618a82 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -22,8 +22,10 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -41,7 +43,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -75,7 +76,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, auto node_it = node_map.find(remap_it->second); if (node_it == node_map.end()) { // Strip off the aot_feed_#/ prefix. - StringPiece name(remap_it->second); + absl::string_view name(remap_it->second); const auto index = name.find('/'); if (index > 0) name.remove_prefix(index + 1); return errors::InvalidArgument( @@ -89,7 +90,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, // explicitly specify or override them. 
Node* arg_node = nullptr; TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp) + NodeBuilder(absl::StrCat("_arg_", arg_index), kArgOp) .Attr("T", BaseType(feed_node->output_type(output_index))) .Attr("index", arg_index) .Attr(kFeedIdAttr, TensorIdToString(feed.id())) @@ -136,7 +137,7 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, // Connects fetch_node -> retval_node. Node* retval_node = nullptr; TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp) + NodeBuilder(absl::StrCat("_retval_", ret_index), kRetvalOp) .Input(fetch_node, id.output_index()) .Attr("T", BaseType(fetch_node->output_type(id.output_index()))) .Attr("index", ret_index) @@ -256,7 +257,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( - strings::StrCat("/device:", DEVICE_CPU_XLA_JIT)); + absl::StrCat("/device:", DEVICE_CPU_XLA_JIT)); } std::vector xla_args; TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); @@ -340,6 +341,13 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), second_copy_def, g.get())); TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping)); + + // Functionalize control flow. + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def)); + // After control flow functionalization, we might have more FunctionDef's + // (then/else branch, loop body). Add them to the graph. 
+ TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto())); + *graph = std::move(g); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 56f7045a98201ed398244f9e3f5ff23788135b75..ab26d939ccba75ce58609ffd71c7ccadbe90cfa8 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -77,8 +77,8 @@ TEST(ConvertGraphDefToXla, Sum) { // Set up arguments. auto x_literal = xla::LiteralUtil::CreateR0(10); auto y_literal = xla::LiteralUtil::CreateR0(32); - auto x_global_or = client->TransferToServer(*x_literal); - auto y_global_or = client->TransferToServer(*y_literal); + auto x_global_or = client->TransferToServer(x_literal); + auto y_global_or = client->TransferToServer(y_literal); TF_EXPECT_OK(x_global_or.status()); TF_EXPECT_OK(y_global_or.status()); std::unique_ptr x_global = @@ -90,8 +90,8 @@ TEST(ConvertGraphDefToXla, Sum) { auto result_or = client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); - std::unique_ptr result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42\n)", result->ToString()); + xla::Literal result = std::move(result_or.ValueOrDie()); + EXPECT_EQ("(s32[]) (\n42\n)", result.ToString()); config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index e284e0b191ac09f9491973166c80b731c8ea51a5..d6f42bac86f1ef359531d67b652d43d851d7ac02 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -20,20 +20,23 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -75,6 +78,8 @@ Status CheckFeedFetchNameConflicts(const string& kind, } // namespace +const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation"; + Status ValidateConfig(const tf2xla::Config& config) { std::set names; for (const tf2xla::Feed& feed : config.feed()) { @@ -112,8 +117,8 @@ Status AddPlaceholdersForFeeds( const string name_port = TensorIdToString(feed->id()); PlaceholderInfo& info = placeholder_info[name_port]; info.feed = feed; - info.placeholder_name = strings::StrCat( - "aot_feed_", feed->id().output_index(), "/", feed->id().node_name()); + info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(), + "/", feed->id().node_name()); (*feed_remapping)[name_port] = info.placeholder_name; } @@ -258,7 +263,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, } string TensorIdToString(const tf2xla::TensorId& id) { - return strings::StrCat(id.node_name(), ":", id.output_index()); + return absl::StrCat(id.node_name(), ":", id.output_index()); } Status SetNodeShardingFromNeighbors(Node* n, 
bool out_edges) { @@ -289,7 +294,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { return Status::OK(); } -void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, +void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, KernelDef* kdef) { for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) { if (constraint.name() == name) { @@ -323,4 +328,101 @@ uint32 GetXLARandomSeed() { return counter.fetch_add(2); } +// TODO(b/77601805): add tests for associated function related stuff. +bool HasAssociatedFunction(const NodeDef& node_def, + FunctionLibraryRuntime* flr) { + if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) { + return true; + } + + if (node_def.op() == FunctionLibraryDefinition::kGradientOp) { + // Skip gradient op. Gradient op has "f" attr, which is set to the function + // we are getting gradient for. That function is not associated with the op. + return false; + } + + for (const auto& iter : node_def.attr()) { + if (iter.second.has_func()) { + return true; + } + } + + return false; +} + +std::vector GetAssociatedFunctions( + const Node& node, FunctionLibraryRuntime* flr) { + std::vector results; + const string& op = node.type_string(); + if (flr->GetFunctionLibraryDefinition()->Contains(op)) { + // This is a function call node. + AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); + results.emplace_back(AssociatedFunctionInfo(op, attrs)); + } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) { + // Skip gradient op. Gradient op has "f" attr, which is set to the function + // we are getting gradient for. That function is not associated with the op. + } else { + // Collect all function attrs for the node. 
+ for (auto& iter : node.attrs()) { + if (iter.second.has_func()) { + VLOG(2) << "Found function attr for node " << node.name() << ": " + << iter.first << " = " << iter.second.func().name(); + results.emplace_back(AssociatedFunctionInfo( + iter.second.func().name(), iter.second.func().attr(), iter.first)); + } + } + } + return results; +} + +Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name) { + switch (associated_function.type()) { + case AssociatedFunctionInfo::kFunctionCallNode: { + // Change this node to call the new function. + NodeDefBuilder builder(node->name(), rewritten_function_name, fld); + for (auto attr : node->attrs()) { + builder.Attr(attr.first, attr.second); + } + for (int i = 0; i < node->num_inputs(); i++) { + Node* input_node; + TF_RETURN_IF_ERROR(node->input_node(i, &input_node)); + builder.Input(input_node->name(), i, node->input_type(i)); + } + builder.Device(node->assigned_device_name().empty() + ? node->requested_device() + : node->assigned_device_name()); + NodeDef node_def; + TF_RETURN_IF_ERROR(builder.Finalize(&node_def)); + Status s; + Node* new_node = graph->AddNode(node_def, &s); + TF_RETURN_IF_ERROR(s); + for (auto edge : node->in_edges()) { + graph->AddEdge(edge->src(), edge->src_output(), new_node, + edge->dst_input()); + } + for (auto edge : node->out_edges()) { + graph->AddEdge(new_node, edge->src_output(), edge->dst(), + edge->dst_input()); + } + graph->RemoveNode(node); + break; + } + case AssociatedFunctionInfo::kFunctionAttr: { + // Change function attr to rewritten functions. 
+ NameAttrList func; + TF_RETURN_IF_ERROR( + GetNodeAttr(node->attrs(), associated_function.attr_name(), &func)); + node->ClearAttr(associated_function.attr_name()); + func.set_name(rewritten_function_name); + node->AddAttr(associated_function.attr_name(), func); + break; + } + } + + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 33620ef810bd4fe897f384474e661e341a448b93..6065d0bb9a3abd23b8911c5049914be8a5f23b99 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -53,12 +54,73 @@ string TensorIdToString(const tf2xla::TensorId& id); Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); // Add an allowed data type to the AttrConstraint with the given name. -void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, +void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, KernelDef* kdef); // Returns the next random seed to use for seeding xla rng. uint32 GetXLARandomSeed(); +// Indicates how a FunctionDef is associated with a graph node (e.g. the node is +// a function call, or the node has function attrs). +class AssociatedFunctionInfo { + public: + enum AssociatedFunctionType { + kFunctionCallNode = 0, + kFunctionAttr = 1, + }; + + // The node is a function call. + AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs) + : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {} + + // The function is an attr of the node. 
+ AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs, + const string& attr_name) + : type_(kFunctionAttr), + func_name_(func_name), + attrs_(attrs), + attr_name_(attr_name) {} + + AssociatedFunctionType type() const { return type_; } + + const string& func_name() const { return func_name_; } + + const string& attr_name() const { return attr_name_; } + + const AttrValueMap& attrs() const { return attrs_; } + + private: + // Available for all instances. + AssociatedFunctionType type_; + string func_name_; + AttrValueMap attrs_; + + // Only available if the function is defined in an attr. + string attr_name_; +}; + +// Returns whether the NodeDef has an associated function. +bool HasAssociatedFunction(const NodeDef& node_def, + FunctionLibraryRuntime* flr); + +// Gets functions associated with the node. Current cases: +// 1. For function call node, its function name; +// 2. For nodes like XlaWhile/XlaIf, all their function attributes. +std::vector GetAssociatedFunctions( + const Node& node, FunctionLibraryRuntime* flr); + +// Changes associated functions for the node. Current cases: +// 1. For function call node, creates a new node with the new function name and +// removes the old node; +// 2. For nodes like XlaWhile/XlaIf, modifies their function attributes. +Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name); + +// Attribute to mark nodes to be executed on host. 
+extern const char kXlaOutsideCompilationAttrName[]; + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 2b1f724dc7b2e2bb6d06115827f92bf0670955b3..68441b3d4790b17bd06accff3fcdc8ccee79bbb7 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" @@ -25,8 +27,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -153,7 +153,7 @@ static tf2xla::Config FetchesConfig(std::vector fetches) { tf2xla::Config config; for (const auto& fetch_node_name : fetches) { auto* fetch = config.add_fetch(); - fetch->set_name(strings::StrCat("fetch_", fetch_node_name)); + fetch->set_name(absl::StrCat("fetch_", fetch_node_name)); fetch->mutable_id()->set_node_name(fetch_node_name); } return config; diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index c969212a1bfaa6cab0d896ee074cfd4e2b283ae4..d00b1376620c0c9d112c7d7426758f6d3f25e86f 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -26,21 +26,26 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { *type = xla::PRED; return Status::OK(); case tensorflow::DT_INT8: + case tensorflow::DT_QINT8: *type = 
xla::S8; return Status::OK(); case tensorflow::DT_INT16: + case tensorflow::DT_QINT16: *type = xla::S16; return Status::OK(); case tensorflow::DT_INT32: + case tensorflow::DT_QINT32: *type = xla::S32; return Status::OK(); case tensorflow::DT_INT64: *type = xla::S64; return Status::OK(); case tensorflow::DT_UINT8: + case tensorflow::DT_QUINT8: *type = xla::U8; return Status::OK(); case tensorflow::DT_UINT16: + case tensorflow::DT_QUINT16: *type = xla::U16; return Status::OK(); case tensorflow::DT_UINT32: @@ -64,12 +69,6 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_COMPLEX64: *type = xla::C64; return Status::OK(); - case tensorflow::DT_QUINT8: - *type = xla::U8; - return Status::OK(); - case tensorflow::DT_QINT32: - *type = xla::S32; - return Status::OK(); default: return errors::InvalidArgument( "Unsupported type in DataTypeToPrimitiveType ", diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index d98237bd5c9288e6337e10c19c2d7574ad2e4c97..7f860500c75667a920505dbf498e3da4b388fb90 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -76,12 +76,11 @@ class XlaCompilationAllocator : public Allocator { XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, DeviceType type) - : LocalDevice( - options, - Device::BuildDeviceAttributes( - strings::StrCat("/device:", type.type(), ":0"), type, - Bytes(256 << 20), DeviceLocality(), - strings::StrCat("device: XLA compilation device ", type.type()))), + : LocalDevice(options, Device::BuildDeviceAttributes( + absl::StrCat("/device:", type.type(), ":0"), + type, Bytes(256 << 20), DeviceLocality(), + absl::StrCat("device: XLA compilation device ", + type.type()))), allocator_(new XlaCompilationAllocator()) {} XlaCompilationDevice::~XlaCompilationDevice() {} diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc 
b/tensorflow/compiler/tf2xla/xla_compiler.cc index 0c300c282e9698534af6372b2f2ddae06f88db24..105f3b61d59acc7ed502216a5e0ceb69ee914131 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -20,10 +20,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" @@ -149,6 +149,9 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, TF_RETURN_WITH_CONTEXT_IF_ERROR( GetFunctionBody(function, flib_runtime_, fbody), "Local lookup failed with: ", status.error_message()); + VLOG(4) << "Function " << function.name() << " in flib_runtime_"; + } else { + VLOG(4) << "Function " << function.name() << " in local_flib_runtime_"; } return Status::OK(); } @@ -198,14 +201,14 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // lowest-numbered core that consumes the argument. We choose the // lowest-numbered core so the assignment is deterministic. for (Node* n : graph->nodes()) { - if (StringPiece(n->type_string()) == "_Arg") { + if (absl::string_view(n->type_string()) == "_Arg") { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); } } // Do _Retval as a second loop, in case the retval's input is an _Arg (which // may have gotten a device assignment from the first loop). 
for (Node* n : graph->nodes()) { - if (StringPiece(n->type_string()) == "_Retval") { + if (absl::string_view(n->type_string()) == "_Retval") { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); } } @@ -213,8 +216,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileFunction: " << dump_graph::DumpGraphToFile( - strings::StrCat("xla_compile_function_", function_id), - *graph); + absl::StrCat("xla_compile_function_", function_id), *graph); } VLOG(1) << "===================================================="; @@ -292,6 +294,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, "Invalid resource type in XLAShapeForArgument()"); } } + case XlaCompiler::Argument::kToken: { + *xla_shape = xla::ShapeUtil::MakeTokenShape(); + return Status::OK(); + } case XlaCompiler::Argument::kInvalid: return errors::Internal("Invalid argument type in XLAShapeForArgument()"); } @@ -490,7 +496,8 @@ Status XlaCompiler::BuildArguments( } break; - case XlaCompiler::Argument::kParameter: { + case XlaCompiler::Argument::kParameter: + case XlaCompiler::Argument::kToken: { input_mapping->push_back(i); break; } @@ -522,7 +529,7 @@ Status XlaCompiler::BuildArguments( // Use the _Arg nodes in the graph to resolve core assignments. for (const Node* n : graph.nodes()) { - if (StringPiece(n->type_string()) != "_Arg") continue; + if (absl::string_view(n->type_string()) != "_Arg") continue; int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); TF_RET_CHECK(index >= 0 && index < args.size()) @@ -581,7 +588,7 @@ Status XlaCompiler::BuildArguments( builder, core == -1 ? 
absl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], - strings::StrCat("arg", i)); + absl::StrCat("arg", i)); } } @@ -617,6 +624,10 @@ Status XlaCompiler::BuildArguments( arg_expression.set_handle(arg_handles[i]); } break; + case XlaCompiler::Argument::kToken: { + arg_expression.set_handle(arg_handles[i]); + break; + } case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: return errors::Internal( @@ -644,7 +655,7 @@ Status XlaCompiler::CompileSingleOp( // dependency edge to the _SOURCE node. for (int64 i = 0; i < ctx->num_inputs(); ++i) { Node* node; - string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); + string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); Status status = NodeBuilder(name, "_Arg") .ControlInput(graph->source_node()) .Attr("T", ctx->input_dtype(i)) @@ -657,7 +668,7 @@ Status XlaCompiler::CompileSingleOp( // Similarly with return values, create dummy _Retval nodes fed by `node`. 
for (int64 i = 0; i < ctx->num_outputs(); ++i) { Node* node; - string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); + string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); Status status = NodeBuilder(name, "_Retval") .Input(main_node, i) .Attr("T", ctx->expected_output_dtype(i)) @@ -693,7 +704,7 @@ Status ValidateGraph(const Graph* graph, const DeviceType& device_type, const string& name) { auto maybe_error = [&](const Node* node, const Status& s) -> Status { if (!s.ok()) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( "Detected unsupported operations when trying to compile graph ", name, " on ", device_type.type_string(), ": ", node->def().op(), " (", s.error_message(), ")", FormatNodeForError(*node))); @@ -734,18 +745,13 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " << dump_graph::DumpGraphToFile( - strings::StrCat("xla_compile_graph_", name), *graph); + absl::StrCat("xla_compile_graph_", name), *graph, + flib_runtime_->GetFunctionLibraryDefinition()); } // Report the error here if initialization failed. TF_RETURN_IF_ERROR(initialization_status_); - // Converts Tensorflow's graph control-flow constructs into functional - // control-flow that can be compiled into XLA code. - TF_RETURN_IF_ERROR( - FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), - graph.get(), local_flib_def_.get())); - // Detect invalid nodes. // FunctionalizeControlFlow may remove some nodes from the graph. TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, @@ -758,23 +764,71 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, &options_.shape_representation_fn); core::ScopedUnref context_unref(context); + std::vector real_args(args); + int token_input_index = -1; + if (options.add_token_input_output) { + // Add extra token input. 
+ token_input_index = real_args.size(); + + XlaCompiler::Argument token_arg; + token_arg.kind = XlaCompiler::Argument::kToken; + real_args.push_back(token_arg); + } + std::vector arg_expressions; std::vector arg_cores; - TF_RETURN_IF_ERROR( - BuildArguments(*graph, args, options.use_tuple_arg, &builder, context, - &arg_cores, &arg_expressions, &result->input_mapping, - &result->xla_input_shapes, options.is_entry_computation)); + TF_RETURN_IF_ERROR(BuildArguments( + *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores, + &arg_expressions, &result->input_mapping, &result->xla_input_shapes, + options.is_entry_computation)); context->set_args(std::move(arg_expressions)); + PushNodeTokenMapping(); + // Use std::set instead of std::unordered_set to ensure determinism. + std::set output_node_token_inputs; + if (token_input_index != -1) { + // Original token comes from input. + auto arg_expression = context->args()[token_input_index]; + TF_RETURN_IF_ERROR( + SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle())); + + // Calculate token inputs for output token. + output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph); + + // If there's no side-effecting op in the graph, use token input as token + // output. + if (output_node_token_inputs.empty()) { + output_node_token_inputs.insert(kXlaTokenArgNodeName); + } + } else if (options.is_entry_computation) { + // Original token is manually created. + if (HasSideEffectingNodes(*graph)) { + TF_RETURN_IF_ERROR( + SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder))); + } + } + TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, flib_runtime_, NextStepId())); + if (token_input_index != -1) { + // Add extra token output. 
+ std::vector token_inputs; + for (const auto& node_name : output_node_token_inputs) { + auto token_or = GetNodeToken(node_name); + TF_RETURN_IF_ERROR(token_or.status()); + token_inputs.push_back(token_or.ValueOrDie()); + } + TF_RETURN_IF_ERROR( + context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs))); + } + TF_RETURN_IF_ERROR(PopNodeTokenMapping()); int num_nonconst_outputs; int num_computation_outputs; result->computation = std::make_shared(); result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( - args, arg_cores, context->retvals(), context->resources(), + real_args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, @@ -913,4 +967,47 @@ Status XlaCompiler::SetHostComputeControlDependency( return Status::OK(); } +void XlaCompiler::PushNodeTokenMapping() { + node_token_mapping_stack_.emplace(std::map{}); +} + +Status XlaCompiler::PopNodeTokenMapping() { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is " + "empty."); + } + node_token_mapping_stack_.pop(); + return Status::OK(); +} + +Status XlaCompiler::SetNodeToken(const string& node_name, + const xla::XlaOp& op) { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling SetNodeToken() when node_token_mapping_stack_ is " + "empty."); + } + auto insert_result = node_token_mapping_stack_.top().insert({node_name, op}); + if (!insert_result.second) { + return errors::FailedPrecondition("Token mapping already exists for node ", + node_name); + } + return Status::OK(); +} + +xla::StatusOr XlaCompiler::GetNodeToken(const string& node_name) { + if (node_token_mapping_stack_.empty()) { + return errors::FailedPrecondition( + "Calling GetNodeToken() when 
node_token_mapping_stack_ is " + "empty."); + } + auto iter = node_token_mapping_stack_.top().find(node_name); + if (iter == node_token_mapping_stack_.top().end()) { + return errors::FailedPrecondition("Cannot find token mapping for node ", + node_name); + } + return iter->second; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 8f4a9858ed63403b9d0f967b61d3f690f12df21a..2cc603a58016a509fafdf6f95423dd6c0864cce3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ +#include + #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -26,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" @@ -106,6 +109,9 @@ class XlaCompiler { // Argument is a run-time parameter. kParameter, + + // Argument is an XLA token. + kToken, }; Kind kind = kInvalid; @@ -179,6 +185,9 @@ class XlaCompiler { // True when compiling the entry computation, false for subcomputations // (while, call, etc.) bool is_entry_computation = true; + + // True when we should add XLA input & output to the graph/function. 
+ bool add_token_input_output = false; }; struct OutputDescription { @@ -384,6 +393,11 @@ class XlaCompiler { xla::Client* client() const { return options_.client; } FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } + void PushNodeTokenMapping(); + Status PopNodeTokenMapping(); + Status SetNodeToken(const string& node_name, const xla::XlaOp& op); + xla::StatusOr GetNodeToken(const string& node_name); + private: // Sets the function body `fbody` to the one registered as `function`. Status FindFunctionBody(const NameAttrList& function, @@ -448,6 +462,15 @@ class XlaCompiler { std::unordered_map host_compute_control_output_; + // This is used to store <node name, token output> mapping. Side-effecting + // ops call SetNodeToken() to record its token output, so later side-effecting + // ops can use GetNodeToken() to get it and use it as token input. + // + // It's a stack because we need a mapping like this for each level of nested + // CompileGraph() call. In CompileGraph(), we will push a new mapping to the + // stack, and pop the mapping before returning. + std::stack> node_token_mapping_stack_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index be3c93ae47bf16a67ed4fac34a99997cc7888559..72b17d04fc42eb00781e96b412465b73fb29a5c2 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -20,10 +20,12 @@ limitations under the License. 
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/side_effect_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -32,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -205,27 +208,22 @@ TEST_F(XlaCompilerTest, Simple) { std::move(graph), args, &result)); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({4, 143}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation of a graph where the _Retval node is not necessarily last @@ -261,23 +259,20 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { args, &result)); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal)); } // Tests that the compiler doesn't reorder the parameters. @@ -405,23 +400,19 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({-7, -42}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get()}); - EXPECT_TRUE( - xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } { @@ -440,24 +431,21 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { EXPECT_FALSE(result.outputs[1].is_constant); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR0(7); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({-7, -42}); - std::unique_ptr expected = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); + xla::Literal expected0 = xla::LiteralUtil::CreateR0(7); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal expected = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal)); } } @@ -616,10 +604,17 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { auto instr1 = c1.instructions(j); auto instr2 = c2.instructions(j); instr1.clear_name(); + instr1.clear_id(); + instr1.clear_operand_ids(); instr2.clear_name(); - // The names of instructions were uniquified by the XlaBuilder, the rest - // of the fields should be identical. + instr2.clear_id(); + instr2.clear_operand_ids(); + // The names of instructions were uniquified by the XlaBuilder and the + // unique ids may be different, the rest of the fields should be + // identical. 
string str1, str2; + LOG(INFO) << "instr1 = " << instr1.DebugString(); + LOG(INFO) << "instr2 = " << instr2.DebugString(); instr1.AppendPartialToString(&str1); instr2.AppendPartialToString(&str2); EXPECT_EQ(str1, str2); @@ -669,34 +664,26 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { update.tensor_array_gradients_accessed); // Tests that the generated computation works. - std::unique_ptr input_base = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr input_grad2 = - xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr input = - xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()}); + xla::Literal input_base = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal input_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2}); std::unique_ptr param0_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); + client_->TransferToServer(input).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param0_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr output_read = - xla::LiteralUtil::CreateR0(42); - std::unique_ptr output_base = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr output_grad1 = - xla::LiteralUtil::CreateR1({0, 1}); - std::unique_ptr output_grad2 = - xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr output_resource = xla::LiteralUtil::MakeTuple( - {output_base.get(), output_grad1.get(), output_grad2.get()}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal output_read = xla::LiteralUtil::CreateR0(42); + xla::Literal output_base = xla::LiteralUtil::CreateR1({7, 42}); + 
xla::Literal output_grad1 = xla::LiteralUtil::CreateR1({0, 1}); + xla::Literal output_grad2 = xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal output_resource = + xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&output_read, &output_resource}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests compilation and execution of a graph that adds two tensors. @@ -863,29 +850,24 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { void RunAndCheckVariablesComputation( xla::Client* client, const XlaCompiler::CompilationResult& result) { - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param0_data = - client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({5, 144}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({4, 143}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({5, 144}); + xla::Literal expected1 = 
xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests a simple graph that reads and writes a variable. @@ -949,20 +931,17 @@ TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { std::move(graph), args, &result)); // Tests that the generated computation works. - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_->Execute(*result.computation, {param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } TEST_F(XlaCompilerTest, ReturnResourceHandle) { @@ -1066,29 +1045,27 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR2({{4, 55}, {1, -3}}); - std::unique_ptr param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = + xla::Literal expected0 = xla::LiteralUtil::CreateR2({{27, 67}, {35, 402}}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({26, 66, 34, 401}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { @@ -1135,29 +1112,26 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. 
- std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({4, 55, 1, -3}); - std::unique_ptr param1_literal = + xla::Literal param1_literal = xla::LiteralUtil::CreateR1({22, 11, 33, 404}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = client_ ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) .ConsumeValueOrDie(); - std::unique_ptr actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({27, 67, 35, 402}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({26, 66, 34, 401}); - std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); + + xla::Literal expected0 = xla::LiteralUtil::CreateR1({27, 67, 35, 402}); + xla::Literal expected1 = xla::LiteralUtil::CreateR1({26, 66, 34, 401}); + xla::Literal expected_literal = + xla::LiteralUtil::MakeTuple({&expected0, &expected1}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } // Tests a graph which has a function with an invalid op. 
@@ -1252,25 +1226,73 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); XlaCompiler::CompilationResult result; - status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", - std::move(graph_copy), args, &result); - ASSERT_FALSE(status.ok()); - EXPECT_TRUE( - absl::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: {{node NoOp}}")) - << status.error_message(); + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result)); + } +} + +class DummySideEffectingOp : public XlaOpKernel { + public: + explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken( + name(), xla::CreateToken(ctx->builder()))); } +}; + +REGISTER_OP("DummySideEffectingOp"); + +REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp); + +TEST_F(XlaCompilerTest, TokenInputAndOutput) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef side_effecting_op; + side_effecting_op.set_name("DummySideEffectingOp"); + side_effecting_op.set_op("DummySideEffectingOp"); + AddNodeAttr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}, &side_effecting_op); + Status status; + graph->AddNode(side_effecting_op, &status); + TF_ASSERT_OK(status); + EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get())); - // Fix control edges for NoOp. + const std::vector empty_args; { + // The case for entry computation: we don't add token input/output. Instead, + // we use CreateToken HLO to create the entry token. 
+ XlaCompiler::CompileOptions options; + options.is_entry_computation = true; + options.add_token_input_output = false; + XlaCompiler compiler(DefaultOptions()); + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); CopyGraph(*graph, graph_copy.get()); - EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", - std::move(graph_copy), args, &result)); - EXPECT_EQ(0, result.resource_updates.size()); + TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), + empty_args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 0); + EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0); + } + { + // The case for non-entry computation (e.g. while loop body). We add token + // input/output. + XlaCompiler::CompileOptions options; + options.is_entry_computation = false; + options.add_token_input_output = true; + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy), + empty_args, &result)); + EXPECT_EQ(result.xla_input_shapes.size(), 1); + EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0])); + EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape)); + EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1); + EXPECT_TRUE(xla::ShapeUtil::IsToken( + xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0))); } } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 24a4b92b45a3f3563e435fa074fce595d6c0b263..f247570d72c0287a33695de3d778cce2a2418921 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -32,7 
+32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -120,6 +119,17 @@ Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { return Status::OK(); } +Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) { + VLOG(1) << "Adding retval index " << retvals_.size() + << " with token to XLA computation"; + XlaExpression e; + e.set_handle(token); + // We use DT_INVALID because there is no TF DataType which corresponds to XLA + // token. XlaCompiler handles this case separately, so putting it here is OK. + retvals_.push_back(Retval{DT_INVALID, TensorShape(), e}); + return Status::OK(); +} + xla::XlaBuilder* XlaContext::builder() { return builder_; } Status XlaContext::CreateResource( diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 4da891634e97dd67af0ef09ef33dbc7a4d19743b..d7dbdc957f0e7969db5098b815381866cdc71ab6 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -89,6 +89,9 @@ class XlaContext : public ResourceBase { // As for Retval, but for return values that are resource handles. Status AddResourceRetval(int retval_index, XlaResource* resource); + // As for Retval, but for return values that are XLA tokens. + Status AppendTokenRetval(const xla::XlaOp& token); + // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` // constructor for a description of the remaining arguments. 
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 1499c99ed15eceaf6bfa2ef0dd1d5885b1e5fc58..2a9eaeee146bf6d792e010df7e041f9986b2c77e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -67,7 +67,7 @@ const xla::XlaOp& XlaOpKernelContext::Input(int index) { return GetComputationFromTensor(context_->input(index)); } -const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) { +const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) { return GetComputationFromTensor(GetInputTensorByName(name)); } @@ -75,7 +75,7 @@ TensorShape XlaOpKernelContext::InputShape(int index) { return context_->input(index).shape(); } -TensorShape XlaOpKernelContext::InputShape(StringPiece name) { +TensorShape XlaOpKernelContext::InputShape(absl::string_view name) { return GetInputTensorByName(name).shape(); } @@ -83,6 +83,10 @@ DataType XlaOpKernelContext::input_type(int index) const { return context_->input(index).dtype(); } +DataType XlaOpKernelContext::InputType(absl::string_view name) { + return GetInputTensorByName(name).dtype(); +} + xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { xla::PrimitiveType type; Status status = DataTypeToPrimitiveType(input_type(index), &type); @@ -100,7 +104,7 @@ Status XlaOpKernelContext::ConstantInput(int index, } static xla::StatusOr InputIndex(XlaOpKernelContext* context, - StringPiece name) { + absl::string_view name) { int start, stop; TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); if (stop != start + 1) { @@ -112,7 +116,7 @@ static xla::StatusOr InputIndex(XlaOpKernelContext* context, return start; } -Status XlaOpKernelContext::ConstantInput(StringPiece name, +Status XlaOpKernelContext::ConstantInput(absl::string_view name, xla::Literal* constant_literal) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInput(index, constant_literal); @@ -213,16 +217,15 
@@ Status XlaOpKernelContext::ConstantInputReshaped( context_->op_kernel().name(), " input ", index, ".\nError: ", constant_graph.status().error_message()); } - xla::StatusOr> computed = - compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(), - &layout); + xla::StatusOr computed = compiler()->client()->ComputeConstant( + constant_graph.ValueOrDie(), &layout); if (!computed.ok()) { return errors::Internal("Error evaluating ", context_->op_kernel().name(), " input ", index, - "as a compile-time constant.\nError: ", + " as a compile-time constant.\nError: ", computed.status().error_message()); } - *constant_literal = std::move(*computed.ValueOrDie()); + *constant_literal = std::move(computed).ValueOrDie(); return Status::OK(); } @@ -265,7 +268,7 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) { return LiteralToInt64Scalar(literal, out); } -Status XlaOpKernelContext::ConstantInputAsIntScalar(StringPiece name, +Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name, int64* out) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInputAsIntScalar(index, out); @@ -305,7 +308,7 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index, return LiteralToInt64Vector(literal, out); } -Status XlaOpKernelContext::ConstantInputAsIntVector(StringPiece name, +Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name, std::vector* out) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInputAsIntVector(index, out); @@ -344,7 +347,7 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, } } -Status XlaOpKernelContext::ConstantInputAsInt64Literal(StringPiece name, +Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInputAsInt64Literal(index, out); @@ -361,7 +364,7 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, 
TensorShape* shape) { return Status::OK(); } -Status XlaOpKernelContext::InputList(StringPiece name, +Status XlaOpKernelContext::InputList(absl::string_view name, std::vector* handles, std::vector* shapes) { OpInputList inputs; @@ -376,7 +379,7 @@ Status XlaOpKernelContext::InputList(StringPiece name, } Status XlaOpKernelContext::ConstantInputList( - StringPiece name, std::vector* outputs) { + absl::string_view name, std::vector* outputs) { int start, stop; TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); outputs->resize(stop - start); @@ -429,8 +432,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, value); } -Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type, - TensorShape* shape, +Status XlaOpKernelContext::ReadVariableInput(absl::string_view name, + DataType type, TensorShape* shape, xla::XlaOp* value) { return ReadVariableInputTensor(GetInputTensorByName(name), type, context_, shape, value); @@ -564,7 +567,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, handle, builder()); } -Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type, +Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type, xla::XlaOp handle) { TF_RET_CHECK(handle.valid()); return AssignVariableTensor(GetInputTensorByName(name), type, context_, @@ -610,7 +613,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( return XlaContext::Get(context_).GetOrCreateMul(type); } -const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) { +const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { const Tensor* tensor; CHECK(context_->input(name, &tensor).ok()); return *tensor; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 45cfa7da740c38afde0158568a019a4426992b64..a3a0d10cc06cd4afceec728b7dbe287389099b9d 100644 --- 
a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -71,6 +71,9 @@ class XlaOpKernelContext { // Returns the type of input `index`. DataType input_type(int index) const; + // Returns the type of input `name`. + DataType InputType(absl::string_view name); + // Returns the type of input `index` as an xla::PrimitiveType. If the type // is not representable as an XLA type, sets an error status and returns // xla::PRIMITIVE_TYPE_INVALID. @@ -79,15 +82,15 @@ class XlaOpKernelContext { // Returns the shape of input `index`. TensorShape InputShape(int index); - // Returns the shape of input `name`. - TensorShape InputShape(StringPiece name); + // Returns the shape of input with name `name`. + TensorShape InputShape(absl::string_view name); // Returns input `index` as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. const xla::XlaOp& Input(int index); // Returns input `name` as a XlaOp. - const xla::XlaOp& Input(StringPiece name); + const xla::XlaOp& Input(absl::string_view name); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -97,7 +100,7 @@ class XlaOpKernelContext { // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. - Status InputList(StringPiece name, std::vector* handles, + Status InputList(absl::string_view name, std::vector* handles, std::vector* shapes); // Helper methods for constant inputs. @@ -106,7 +109,7 @@ class XlaOpKernelContext { // expression cannot be evaluated, e.g., because it depends on unbound // parameters, returns a non-OK status. 
Status ConstantInput(int index, xla::Literal* constant_literal); - Status ConstantInput(StringPiece name, xla::Literal* constant_literal); + Status ConstantInput(absl::string_view name, xla::Literal* constant_literal); // Evaluates input `index`, reshapes it to `new_shape` if new_shape != // InputShape(index), and stores it in `*constant_literal`. If the input @@ -118,14 +121,15 @@ class XlaOpKernelContext { // Converts a constant scalar int32 or int64 tensor into an int64. Status ConstantInputAsIntScalar(int index, int64* out); - Status ConstantInputAsIntScalar(StringPiece name, int64* out); + Status ConstantInputAsIntScalar(absl::string_view name, int64* out); // Converts a constant scalar float32 or float64 tensor into a float64. Status ConstantInputAsFloatScalar(int index, double* out); // Converts a constant 1D int32 or int64 tensor into a vector of int64s. Status ConstantInputAsIntVector(int index, std::vector* out); - Status ConstantInputAsIntVector(StringPiece name, std::vector* out); + Status ConstantInputAsIntVector(absl::string_view name, + std::vector* out); // Reshapes and converts a constant int32 or int64 tensor into a vector of // int64s. @@ -133,7 +137,7 @@ class XlaOpKernelContext { // Converts a constant int32 or int64 Tensor into an xla int64 Literal. Status ConstantInputAsInt64Literal(int index, xla::Literal* out); - Status ConstantInputAsInt64Literal(StringPiece name, xla::Literal* out); + Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out); // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); @@ -141,7 +145,7 @@ class XlaOpKernelContext { // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. 
- Status ConstantInputList(StringPiece name, + Status ConstantInputList(absl::string_view name, std::vector* literals); // Outputs @@ -190,8 +194,8 @@ class XlaOpKernelContext { xla::XlaOp* value); // Reads the current value of the resouce variable referred to by input // `name`. - Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape, - xla::XlaOp* value); + Status ReadVariableInput(absl::string_view name, DataType type, + TensorShape* shape, xla::XlaOp* value); // Assigns the value `handle` to the variable referenced by input // `input_index`. The variable must be of `type`. Returns an error if the @@ -199,7 +203,8 @@ class XlaOpKernelContext { // different shape. Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); // Assigns the value `handle` to the variable referenced by input `name`. - Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle); + Status AssignVariable(absl::string_view name, DataType type, + xla::XlaOp handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); @@ -248,7 +253,7 @@ class XlaOpKernelContext { private: // Returns the tensor of input `name`. 
- const Tensor& GetInputTensorByName(StringPiece name); + const Tensor& GetInputTensorByName(absl::string_view name); OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index dae2d956ca61a18f7da61fcd0a569a55a6286663..b0eeee3174eda7f552f1d8a1d5ece877e93f94ab 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -371,26 +371,28 @@ XlaOpRegistry& XlaOpRegistry::Instance() { return *r; } -XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) { +XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(absl::string_view name) { registration_.reset(new XlaOpRegistry::OpRegistration); registration_->name = string(name); } -XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { +XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name( + absl::string_view name) { XlaOpRegistrationBuilder registration(name); return registration; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( - absl::Span devices) { + absl::Span devices) { registration_->has_device_whitelist = true; - for (StringPiece device : devices) { + for (absl::string_view device : devices) { registration_->device_whitelist.emplace(device); } return *this; } -XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) { +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( + absl::string_view device) { registration_->has_device_whitelist = true; registration_->device_whitelist.emplace(device); return *this; @@ -407,7 +409,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( - StringPiece attr_name, DataType allowed) { + absl::string_view attr_name, DataType allowed) { std::set& types = registration_->type_constraints[string(attr_name)]; types.insert(allowed); @@ -415,7 +417,7 @@ XlaOpRegistrationBuilder& 
XlaOpRegistrationBuilder::TypeConstraint( } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( - StringPiece attr_name, absl::Span allowed) { + absl::string_view attr_name, absl::Span allowed) { std::set& types = registration_->type_constraints[string(attr_name)]; for (DataType t : allowed) { @@ -425,7 +427,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( - StringPiece input_name) { + absl::string_view input_name) { registration_->compile_time_constant_inputs.emplace(input_name); return *this; } @@ -452,7 +454,7 @@ XlaOpRegistrar::XlaOpRegistrar( } XlaBackendRegistrar::XlaBackendRegistrar( - StringPiece name, absl::Span types, + absl::string_view name, absl::Span types, XlaOpRegistry::BackendOpFilter op_filter) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); registry.RegisterBackend(string(name), types, op_filter); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index c640842dc0d4fb3aff64d8388b4ffd3fdcee9faf..74a4885f1f029628817f6ec3a36fcb98719d6a41 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -232,18 +232,18 @@ class XlaOpRegistry { class XlaOpRegistrationBuilder { public: // Starts an operator registration chain. - static XlaOpRegistrationBuilder Name(StringPiece name); + static XlaOpRegistrationBuilder Name(absl::string_view name); // Specifies a whitelist of devices on which the operator may run. - XlaOpRegistrationBuilder& Device(StringPiece devices); - XlaOpRegistrationBuilder& Device(absl::Span devices); + XlaOpRegistrationBuilder& Device(absl::string_view devices); + XlaOpRegistrationBuilder& Device(absl::Span devices); // Specifies a type constraint for a type variable attribute. Each constraint // specifies the set of types that the type variable may assume. 
- XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name, + XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name, DataType allowed); - XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name, + XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name, absl::Span allowed); // Specifies that a dummy copy of this operator should not be registered on @@ -254,13 +254,13 @@ class XlaOpRegistrationBuilder { XlaOpRegistrationBuilder& AllowResourceTypes(); // Mark 'input_name' as an argument whose value must be known at compile-time. - XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name); + XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name); std::unique_ptr Build( XlaOpRegistry::Factory factory); private: - XlaOpRegistrationBuilder(StringPiece name); + XlaOpRegistrationBuilder(absl::string_view name); std::unique_ptr registration_; }; @@ -288,7 +288,7 @@ class XlaOpRegistrar { class XlaBackendRegistrar { public: - XlaBackendRegistrar(StringPiece name, absl::Span types, + XlaBackendRegistrar(absl::string_view name, absl::Span types, XlaOpRegistry::BackendOpFilter op_filter = nullptr); }; diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 7928fa034725206a752cbfe086d01f15cd235df9..56c2e01055665954b99ea635e56666fbd8b96026 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -43,7 +43,7 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, for (const string& gradient : tensor_array_gradients) { tensor_array_gradients_[gradient].reset(new XlaResource( /*kind=*/kTensorArray, /*arg_num=*/-1, - /*name=*/strings::StrCat("TensorArrayGrad: ", name_), type_, shape_, + /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{})); } } @@ -135,7 +135,7 @@ Status 
XlaResource::GetOrCreateTensorArrayGradient(const string& source, xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, - /*name=*/strings::StrCat("TensorArrayGrad: ", name_), + /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, gradient_value, tensor_array_size_, /*tensor_array_gradients=*/{})); } diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 8818f813127230d3b39d4b48d874b7cfb24b8abc..5dde5b432f136c16d4e3795569499ee5de709763 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -37,8 +37,8 @@ Client::Client(ServiceInterface* stub) : stub_(stub) {} Client::~Client() = default; -StatusOr> Client::Transfer( - const GlobalData& data, const Shape* shape_with_layout) { +StatusOr Client::Transfer(const GlobalData& data, + const Shape* shape_with_layout) { TransferToClientRequest request; *request.mutable_data() = data.handle(); if (shape_with_layout != nullptr) { @@ -114,7 +114,7 @@ Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, return Status::OK(); } -StatusOr> Client::TransferFromOutfeed( +StatusOr Client::TransferFromOutfeed( const Shape* shape_with_layout, int64 replica_id, const DeviceHandle* device_handle) { TransferFromOutfeedRequest request; @@ -162,7 +162,7 @@ Status Client::ResetDevice() { return Status::OK(); } -StatusOr> Client::ExecuteAndTransfer( +StatusOr Client::ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options, ExecutionProfile* execution_profile) { @@ -177,8 +177,8 @@ StatusOr> Client::ExecuteAndTransfer( return Transfer(*data, shape_with_output_layout); } -StatusOr> Client::ComputeConstant( - const XlaComputation& computation, const Layout* output_layout) const { +StatusOr Client::ComputeConstant(const XlaComputation& computation, + const Layout* 
output_layout) const { ComputeConstantGraphRequest request; *request.mutable_computation() = computation.proto(); if (output_layout != nullptr) { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 7960b078686e611a6439af495d266f9084992d29..6f4d33c469f1f885cfeef546e3981dc3417ef71f 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -96,8 +96,8 @@ class Client { // // If shape_with_layout is not nullptr, it points to a shape whose layout will // be the layout of the returned literal. - StatusOr> Transfer( - const GlobalData& data, const Shape* shape_with_layout = nullptr); + StatusOr Transfer(const GlobalData& data, + const Shape* shape_with_layout = nullptr); // Transfer the given literal to the server. This allocates memory on the // device and copies the literal's contents over. Returns a global data handle @@ -122,7 +122,7 @@ class Client { // device_handle and replica_id together specify a particular device; a device // assigned for the given replica_id among the replicas that the given device // handle belongs to. - StatusOr> TransferFromOutfeed( + StatusOr TransferFromOutfeed( const Shape* shape_with_layout, int64 replica_id = 0, const DeviceHandle* device_handle = nullptr); @@ -132,7 +132,7 @@ class Client { // Executes the computation with the given arguments and transfers the result // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). - StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const ExecutionOptions* execution_options = nullptr, @@ -153,7 +153,7 @@ class Client { // // If output_layout is non-null, then the output of the computation will be // stored using that layout. 
- StatusOr> ComputeConstant( + StatusOr ComputeConstant( const XlaComputation& computation, const Layout* output_layout = nullptr) const; diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 6861521acc0db1d640666a6793b898a183ab6a17..25cc37edc43c28a636797c310c8882eea09a0ef3 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -76,7 +76,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client) { if (DataSizeOfShape(shape) < (1LL << 20)) { - StatusOr> literal_status = MakeFakeLiteral(shape); + StatusOr literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { // If we got an Unimplemented error, fall back to making the fake data via // an on-device computation. @@ -84,7 +84,7 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, tensorflow::error::UNIMPLEMENTED); return MakeFakeDataViaDeviceOrDie(shape, client); } - return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie(); + return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie(); } // If the data is large, generate it on-device. 
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 4402ba8762c1538951c326c880fc3b6dd63ef0c6..f96b6c9c261a9686fb647e3da0dcc933cd1f70df 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -195,9 +195,8 @@ Status LocalExecutable::RecordArguments( HloSnapshot* hlo_snapshot) { hlo_snapshot->clear_arguments(); for (const ShapedBuffer* argument : arguments) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - LiteralFromShapedBuffer(*argument)); - *hlo_snapshot->add_arguments() = literal->ToProto(); + TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument)); + *hlo_snapshot->add_arguments() = literal.ToProto(); } return Status::OK(); } @@ -205,13 +204,12 @@ Status LocalExecutable::RecordArguments( Status LocalExecutable::RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot) { hlo_snapshot->clear_result(); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - LiteralFromShapedBuffer(*result)); - *hlo_snapshot->mutable_result() = literal->ToProto(); + TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result)); + *hlo_snapshot->mutable_result() = literal.ToProto(); return Status::OK(); } -StatusOr> LocalExecutable::LiteralFromShapedBuffer( +StatusOr LocalExecutable::LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN(auto stream, backend_->BorrowStream(shaped_buffer.device_ordinal())); @@ -277,7 +275,7 @@ StatusOr LocalClient::LiteralToShapedBuffer( return std::move(scoped_buffer); } -StatusOr> LocalClient::ShapedBufferToLiteral( +StatusOr LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream( shaped_buffer.device_ordinal())); @@ -298,13 +296,13 @@ Status LocalClient::TransferToInfeedLocal(const Literal& literal, literal); } -StatusOr> LocalClient::TransferFromOutfeedLocal( - const Shape& shape, int 
device_ordinal) { +StatusOr LocalClient::TransferFromOutfeedLocal(const Shape& shape, + int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); auto literal = Literal::CreateFromShape(shape); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( - executor, shape, literal.get())); + executor, shape, &literal)); return std::move(literal); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 56c3a3da023ebf124b4bd91c2c608d0cd00a2381..feb2f8ec9dab5bf13afdc866d10ccbe74f8edcb9 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -84,8 +84,7 @@ class LocalExecutable { Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); // Returns a literal containing the contents of the given ShapedBuffer. - StatusOr> LiteralFromShapedBuffer( - const ShapedBuffer& shaped_buffer); + StatusOr LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer); // The ordinal of the device which this executable was compiled for. The // executable can run on all equivalent devices (as determined by @@ -132,8 +131,7 @@ class LocalClient : public Client { // Copy the data from the device contained in the given ShapedBuffer and // return as a Literal. - StatusOr> ShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer); + StatusOr ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid // as long as the handle is valid. @@ -151,8 +149,8 @@ class LocalClient : public Client { // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with // Client::TransferFromOutfeed. 
- StatusOr> TransferFromOutfeedLocal( - const Shape& shape, int device_ordinal); + StatusOr TransferFromOutfeedLocal(const Shape& shape, + int device_ordinal); // Returns the device ordinal that corresponds to the given replica number. // diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index e639028ccda11ae7e873f601c2f95749bce178c0..95ff6432a591f87845729b180397e33a85e5e9a5 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -134,11 +134,12 @@ XlaOp XlaBuilder::ReportErrorOrReturn( StatusOr XlaBuilder::GetProgramShape(int64 root_id) const { TF_RETURN_IF_ERROR(first_error_); - TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size())); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto, + LookUpInstructionByHandle(root_id)); ProgramShape program_shape; - *program_shape.mutable_result() = instructions_[root_id].shape(); + *program_shape.mutable_result() = root_proto->shape(); // Check that the parameter numbers are continuous from 0, and add parameter // shapes and names to the program shape. @@ -181,9 +182,8 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, return; } - CHECK(op_handle < instructions_.size() && op_handle >= 0); - - const HloInstructionProto& instr = instructions_[op_handle]; + const HloInstructionProto& instr = + *(LookUpInstructionByHandle(op_handle).ValueOrDie()); const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie(); switch (opcode) { default: @@ -283,6 +283,7 @@ StatusOr XlaBuilder::Build(int64 root_id) { // Clear data held by this builder. 
this->instructions_.clear(); + this->handle_to_index_.clear(); this->embedded_.clear(); this->parameter_numbers_.clear(); @@ -738,7 +739,7 @@ void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = ShapeUtil::MakeNil(); - *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto(); + *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto(); return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); }); } @@ -820,7 +821,7 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -828,14 +829,13 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, dimension_numbers.add_lhs_contracting_dimensions( lhs_shape.dimensions_size() == 1 ? 
0 : 1); dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers, precision_config_proto); + return DotGeneral(lhs, rhs, dimension_numbers, precision_config); }); } -XlaOp XlaBuilder::DotGeneral( - const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto) { +XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -844,8 +844,8 @@ XlaOp XlaBuilder::DotGeneral( ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); *instr.mutable_dot_dimension_numbers() = dimension_numbers; - if (precision_config_proto != nullptr) { - *instr.mutable_precision_config() = *precision_config_proto; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; } return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); }); @@ -899,28 +899,26 @@ Status XlaBuilder::VerifyConvolution( XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ConvGeneral(lhs, rhs, window_strides, padding, 
CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -948,7 +946,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); }); } @@ -956,11 +954,10 @@ XlaOp XlaBuilder::ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp XlaBuilder::ConvGeneralDilated( @@ -968,8 +965,7 @@ XlaOp XlaBuilder::ConvGeneralDilated( absl::Span> padding, absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -990,14 +986,14 @@ XlaOp XlaBuilder::ConvGeneralDilated( 
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, instr.window(), - dimension_numbers, feature_group_count)); + lhs_shape, rhs_shape, feature_group_count, + instr.window(), dimension_numbers)); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); - if (precision_config_proto != nullptr) { - *instr.mutable_precision_config() = *precision_config_proto; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; } return AddInstruction(std::move(instr), HloOpcode::kConvolution, @@ -2290,7 +2286,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( *program_shape->mutable_result() = root->shape(); // We use std::set to keep the instruction ids in ascending order (which is - // also a valid denpendency order). The related ops will be added to the + // also a valid dependency order). The related ops will be added to the // subgraph in the same order. std::set related_ops; tensorflow::gtl::FlatSet related_calls; // Related computations. @@ -2298,14 +2294,16 @@ StatusOr XlaBuilder::BuildConstantSubGraph( worklist.push(root->id()); related_ops.insert(root->id()); while (!worklist.empty()) { - int64 node = worklist.front(); + int64 handle = worklist.front(); worklist.pop(); - for (int64 id : instructions_[node].operand_ids()) { + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, + LookUpInstructionByHandle(handle)); + for (int64 id : instr_proto->operand_ids()) { if (related_ops.insert(id).second) { worklist.push(id); } } - for (int64 called_id : instructions_[node].called_computation_ids()) { + for (int64 called_id : instr_proto->called_computation_ids()) { related_calls.insert(called_id); } } @@ -2313,7 +2311,9 @@ StatusOr XlaBuilder::BuildConstantSubGraph( // Add related ops to the computation. 
for (int64 id : related_ops) { auto* instr = entry.add_instructions(); - *instr = instructions_[id]; + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src, + LookUpInstructionByHandle(id)); + *instr = *instr_src; // Ensures that the instruction names are unique among the graph. const string& new_name = StrCat(instr->name(), ".", entry.id(), ".", instr->id()); @@ -2420,11 +2420,11 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, absl::Span operands) { TF_RETURN_IF_ERROR(first_error_); - const int64 handle = instructions_.size(); + const int64 handle = GetUniqueId(); instr.set_id(handle); instr.set_opcode(HloOpcodeString(opcode)); if (instr.name().empty()) { - instr.set_name(StrCat(instr.opcode())); + instr.set_name(instr.opcode()); } for (const auto& operand : operands) { if (operand.builder_ == nullptr) { @@ -2442,7 +2442,8 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, *instr.mutable_sharding() = *sharding_; } - instructions_.push_back(instr); + handle_to_index_[handle] = instructions_.size(); + instructions_.push_back(std::move(instr)); XlaOp op(handle, this); return op; @@ -2472,10 +2473,16 @@ StatusOr XlaBuilder::LookUpInstruction( op.handle(), op.builder_->name(), this->name()); } - if (op.handle() >= instructions_.size() || op.handle() < 0) { - return InvalidArgument("no XlaOp value %d", op.handle()); + return LookUpInstructionByHandle(op.handle()); +} + +StatusOr XlaBuilder::LookUpInstructionByHandle( + int64 handle) const { + auto it = handle_to_index_.find(handle); + if (it == handle_to_index_.end()) { + return InvalidArgument("No XlaOp with handle %d", handle); } - return &instructions_[op.handle()]; + return &instructions_[it->second]; } // Enqueues a "retrieve parameter value" instruction for a parameter that was @@ -2594,43 +2601,40 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, } XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto) { - return 
lhs.builder()->Dot(lhs, rhs, precision_config_proto); + const PrecisionConfig* precision_config) { + return lhs.builder()->Dot(lhs, rhs, precision_config); } XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers, - precision_config_proto); + precision_config); } XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->Conv(lhs, rhs, window_strides, padding, - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } -XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, - absl::Span> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { - return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, - padding, feature_group_count, - precision_config_proto); +XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count, + const PrecisionConfig* precision_config) { + return lhs.builder()->ConvWithGeneralPadding( + lhs, rhs, window_strides, padding, feature_group_count, precision_config); } XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - 
precision_config_proto); + precision_config); } XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, @@ -2638,10 +2642,10 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, @@ -2651,10 +2655,10 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneralDilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers, feature_group_count, precision_config_proto); + dimension_numbers, feature_group_count, precision_config); } XlaOp Fft(const XlaOp& operand, FftType fft_type, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 59fbc664f2b35fd00f9b9094d6147847d03797ea..d0c59fa6f27bc265c0868734ed95a196002fbd2e 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stacktrace.h" @@ -496,20 +497,19 @@ class XlaBuilder { // Enqueues a dot instruction onto the computation. 
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. - XlaOp DotGeneral( - const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). @@ -518,7 +518,7 @@ class XlaBuilder { absl::Span window_strides, absl::Span> padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. @@ -527,29 +527,27 @@ class XlaBuilder { absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. 
- XlaOp ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span window_strides, - absl::Span> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. - XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. @@ -958,6 +956,8 @@ class XlaBuilder { HloInstructionProto* instr); StatusOr LookUpInstruction(const XlaOp& op) const; + StatusOr LookUpInstructionByHandle( + int64 handle) const; // Internal helper method that does the building for an arbitrary unary op. XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); @@ -1027,6 +1027,10 @@ class XlaBuilder { // The instructions of this computation. std::vector instructions_; + // A map from XlaOp::Handle to the index in the instructions_ vector where the + // instruction is held. 
+ tensorflow::gtl::FlatMap handle_to_index_; + // The embedded computations used by this computation. Each computation was // the entry computation of some XlaComputation, the key is the unique id of // that XlaComputation. @@ -1150,32 +1154,30 @@ class XlaBuilder { friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions); friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_number, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, @@ -1183,8 +1185,7 @@ class XlaBuilder { absl::Span lhs_dilation, absl::Span 
rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, @@ -1629,27 +1630,27 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). -XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, - absl::Span> padding, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); +XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. 
@@ -1657,7 +1658,7 @@ XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. @@ -1666,17 +1667,18 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. -XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); +XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. 
@@ -2117,12 +2119,12 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, template XlaOp XlaBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*LiteralUtil::CreateR0(value)); + return ConstantLiteral(LiteralUtil::CreateR0(value)); } template XlaOp XlaBuilder::ConstantR1(absl::Span values) { - return ConstantLiteral(*LiteralUtil::CreateR1(values)); + return ConstantLiteral(LiteralUtil::CreateR1(values)); } template @@ -2134,44 +2136,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) { } inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*LiteralUtil::CreateR1(values)); + return ConstantLiteral(LiteralUtil::CreateR1(values)); } template XlaOp XlaBuilder::ConstantR2( std::initializer_list> values) { - return ConstantLiteral(*LiteralUtil::CreateR2(values)); + return ConstantLiteral(LiteralUtil::CreateR2(values)); } template XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantFromArray(const Array& values) { - return ConstantLiteral(*LiteralUtil::CreateFromArray(values)); + return ConstantLiteral(LiteralUtil::CreateFromArray(values)); } template XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { return ConstantLiteral( - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D& values) { - return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D(values)); + return ConstantLiteral(LiteralUtil::CreateR2FromArray2D(values)); } template XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { return ConstantLiteral( - 
*LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); + LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template @@ -2194,12 +2196,12 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { template XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { - return ConstantLiteral(builder, *LiteralUtil::CreateR0(value)); + return ConstantLiteral(builder, LiteralUtil::CreateR0(value)); } template XlaOp ConstantR1(XlaBuilder* builder, absl::Span values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); } template @@ -2212,13 +2214,13 @@ XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) { inline XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR1(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); } template XlaOp ConstantR2(XlaBuilder* builder, std::initializer_list> values) { - return ConstantLiteral(builder, *LiteralUtil::CreateR2(values)); + return ConstantLiteral(builder, LiteralUtil::CreateR2(values)); } template @@ -2226,14 +2228,13 @@ XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, const Array& values, const Layout& layout) { return ConstantLiteral( - builder, - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { return ConstantLiteral(builder, - *LiteralUtil::CreateFromArray(values)); + LiteralUtil::CreateFromArray(values)); } template @@ -2241,15 +2242,14 @@ XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, const Array2D& values, const Layout& layout) { return ConstantLiteral( - builder, - *LiteralUtil::CreateFromArrayWithLayout(values, layout)); + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); } template XlaOp 
ConstantR2FromArray2D(XlaBuilder* builder, const Array2D& values) { return ConstantLiteral(builder, - *LiteralUtil::CreateR2FromArray2D(values)); + LiteralUtil::CreateR2FromArray2D(values)); } template @@ -2258,7 +2258,7 @@ XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, const Layout& layout) { return ConstantLiteral( builder, - *LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); + LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); } template diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 3f7635bd400c6ec87e0e3a739658272e906a72fb..5035f4198890857fcafd0156d7eaeeb4bc164322 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -174,9 +174,9 @@ Literal& Literal::operator=(Literal&& other) { return *this; } -std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { - auto literal = absl::make_unique(shape); - literal->root_piece_->ForEachMutableSubpiece( +Literal LiteralBase::CreateFromShape(const Shape& shape) { + Literal literal(shape); + literal.root_piece_->ForEachMutableSubpiece( [&](const ShapeIndex& index, Piece* piece) { if (ShapeUtil::IsArray(piece->subshape())) { memset(piece->untyped_data(), 0, piece->size_bytes()); @@ -278,8 +278,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } -/* static */ StatusOr> -MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { +/* static */ StatusOr MutableLiteralBase::CreateFromProto( + const LiteralProto& proto) { if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); } @@ -287,9 +287,9 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { return InvalidArgument("LiteralProto has no layout"); } - auto literal = absl::make_unique(proto.shape()); + Literal literal(proto.shape()); - TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + 
TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { const LiteralProto* proto_element = &proto; for (int64 i : index) { @@ -556,38 +556,37 @@ void MutableLiteralBase::PopulateR1(const tensorflow::core::Bitmap& values) { } } -std::unique_ptr LiteralBase::Relayout( - const Layout& new_layout, const ShapeIndex& shape_index) const { +Literal LiteralBase::Relayout(const Layout& new_layout, + const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. Shape new_shape = shape(); Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); *subshape->mutable_layout() = new_layout; - auto result = absl::make_unique(new_shape); - TF_CHECK_OK(result->CopyFrom(*this)); + Literal result(new_shape); + TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr LiteralBase::Relayout( - const Shape& shape_with_layout) const { +Literal LiteralBase::Relayout(const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) << " not compatible with literal shape " << ShapeUtil::HumanString(shape()); - std::unique_ptr result = CreateFromShape(shape_with_layout); + Literal result = CreateFromShape(shape_with_layout); ShapeUtil::ForEachSubshape( - result->shape(), + result.shape(), [this, &result](const Shape& subshape, const ShapeIndex& index) { if (ShapeUtil::IsArray(subshape)) { - TF_CHECK_OK(result->CopyFrom(*this, - /*dest_shape_index=*/index, - /*src_shape_index=*/index)); + TF_CHECK_OK(result.CopyFrom(*this, + /*dest_shape_index=*/index, + /*src_shape_index=*/index)); } }); return result; } -StatusOr> LiteralBase::Broadcast( +StatusOr LiteralBase::Broadcast( const Shape& result_shape, absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return 
InvalidArgument("Broadcast only supports arrays."); @@ -598,14 +597,14 @@ StatusOr> LiteralBase::Broadcast( result_shape.dimensions(dimensions[i])); } - std::unique_ptr result = absl::make_unique(result_shape); + Literal result(result_shape); // scratch_source_index is temporary storage space for the computed index into // the input literal. We put it here to avoid allocating an std::vector in // every iteration of ShapeUtil::ForEachIndex. std::vector scratch_source_index(shape().dimensions_size()); - char* dest_data = static_cast(result->untyped_data()); + char* dest_data = static_cast(result.untyped_data()); const char* source_data = static_cast(untyped_data()); const int64 primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); @@ -627,37 +626,36 @@ StatusOr> LiteralBase::Broadcast( return std::move(result); } -StatusOr> LiteralBase::Reshape( +StatusOr LiteralBase::Reshape( absl::Span dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); } - std::unique_ptr output; + Literal output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { output = Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); } else { - output = CloneToUnique(); + output = Clone(); } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. 
- *output->mutable_shape_do_not_use() = + *output.mutable_shape_do_not_use() = ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); - int64 elements_after = ShapeUtil::ElementsIn(output->shape()); + int64 elements_after = ShapeUtil::ElementsIn(output.shape()); if (elements_before != elements_after) { return InvalidArgument( "Shapes before and after Literal::Reshape have different numbers " "of elements: %s vs %s.", ShapeUtil::HumanString(shape()), - ShapeUtil::HumanString(output->shape())); + ShapeUtil::HumanString(output.shape())); } return std::move(output); } -std::unique_ptr LiteralBase::Transpose( - absl::Span permutation) const { +Literal LiteralBase::Transpose(absl::Span permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) << "Given permutation is not a permutation of dimension numbers"; @@ -687,32 +685,31 @@ std::unique_ptr LiteralBase::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - auto new_literal = absl::make_unique(permuted_shape); - DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), + Literal new_literal(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); + std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes()); return new_literal; } template -std::unique_ptr LiteralBase::SliceInternal( +Literal LiteralBase::SliceInternal( const Shape& result_shape, absl::Span start_indices) const { - auto result_literal = absl::make_unique(result_shape); + Literal result_literal(result_shape); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); - result_literal->EachCell( + result_literal.EachCell( [&](absl::Span indices, NativeT /*value*/) { for (int64 i = 0; i < 
ShapeUtil::Rank(result_shape); ++i) { new_indices[i] = indices[i] + start_indices[i]; } NativeT value = Get(new_indices); - result_literal->Set(indices, value); + result_literal.Set(indices, value); }); return result_literal; } -std::unique_ptr LiteralBase::Slice( - absl::Span start_indices, - absl::Span limit_indices) const { +Literal LiteralBase::Slice(absl::Span start_indices, + absl::Span limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; DimensionVector result_dimensions; @@ -750,12 +747,6 @@ Literal LiteralBase::Clone() const { return result; } -std::unique_ptr LiteralBase::CloneToUnique() const { - auto result = absl::make_unique(shape()); - TF_CHECK_OK(result->CopyFrom(*this)); - return result; -} - string LiteralBase::GetAsString(absl::Span multi_index, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); @@ -1191,14 +1182,14 @@ void LiteralBase::EachCellAsString( namespace { template -std::unique_ptr ConvertBetweenNativeTypesWithConverter( - const LiteralBase& src_literal, const ConverterType& converter) { +Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal, + const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = absl::make_unique(ShapeUtil::ChangeElementType( + Literal result_literal(ShapeUtil::ChangeElementType( src_literal.shape(), primitive_util::NativeToPrimitiveType())); auto src_data = src_literal.data(); - auto dest_data = result_literal->template data(); + auto dest_data = result_literal.template data(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { @@ -1208,8 +1199,7 @@ std::unique_ptr ConvertBetweenNativeTypesWithConverter( } template -std::unique_ptr ConvertBetweenNativeTypes( - const LiteralBase& src_literal) { +Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { 
return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1217,7 +1207,7 @@ std::unique_ptr ConvertBetweenNativeTypes( template typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), - std::unique_ptr>::type + Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return tensorflow::bit_cast(src); @@ -1232,20 +1222,20 @@ BitcastBetweenNativeTypes(const LiteralBase& src_literal) { // identical sizes higher up. template typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), - std::unique_ptr>::type + Literal>::type BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } template -std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { +Literal ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); - auto result_literal = absl::make_unique( + Literal result_literal( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative::type; absl::Span src_data = src_literal.data(); - absl::Span dest_data = result_literal->data(); + absl::Span dest_data = result_literal.data(); int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = complex64(static_cast(src_data[i]), 0); @@ -1254,8 +1244,7 @@ std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { } template -std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, - bool bitcast) { +Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); if (bitcast) { return BitcastBetweenNativeTypes< @@ -1273,9 +1262,9 @@ std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, } template -StatusOr> ConvertIfDestTypeMatches( - const LiteralBase& src_literal, 
PrimitiveType primitive_dest_type, - bool bitcast) { +StatusOr ConvertIfDestTypeMatches(const LiteralBase& src_literal, + PrimitiveType primitive_dest_type, + bool bitcast) { switch (primitive_dest_type) { #define CONVERT_IF_TYPES_MATCH(type) \ case (type): \ @@ -1307,12 +1296,12 @@ StatusOr> ConvertIfDestTypeMatches( PrimitiveType_Name(primitive_dest_type)); } -StatusOr> ConvertSwitch( - const LiteralBase& literal, PrimitiveType primitive_dest_type, - bool bitcast) { +StatusOr ConvertSwitch(const LiteralBase& literal, + PrimitiveType primitive_dest_type, + bool bitcast) { TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { - return literal.CloneToUnique(); + return literal.Clone(); } switch (literal.shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ @@ -1342,12 +1331,12 @@ StatusOr> ConvertSwitch( } // namespace -StatusOr> LiteralBase::Convert( +StatusOr LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); } -StatusOr> LiteralBase::BitcastConvert( +StatusOr LiteralBase::BitcastConvert( PrimitiveType primitive_dest_type) const { if (primitive_util::BitWidth(shape().element_type()) != primitive_util::BitWidth(primitive_dest_type)) { @@ -1362,17 +1351,8 @@ StatusOr> LiteralBase::BitcastConvert( return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); } -StatusOr> LiteralBase::ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16) const { +StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { if (!ShapeUtil::IsTuple(dest_shape)) { - if (round_f32_to_bf16 && shape().element_type() == F32 && - dest_shape.element_type() == BF16) { - auto converter = [](float src) { - return tensorflow::bfloat16::round_to_bfloat16(src); - }; - return ConvertBetweenNativeTypesWithConverter(*this, - converter); - } return Convert(dest_shape.element_type()); } std::vector elements; @@ -1381,11 
+1361,9 @@ StatusOr> LiteralBase::ConvertToShape( TF_ASSIGN_OR_RETURN( auto new_element, element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); - elements.push_back(std::move(*new_element)); + elements.push_back(std::move(new_element)); } - auto converted = absl::make_unique(); - *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements)); - return std::move(converted); + return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements)); } /* static */ Literal MutableLiteralBase::MoveIntoTuple( @@ -1782,6 +1760,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case PRED: CopyToRepeatedField(proto->mutable_preds(), data()); break; + case S8: + proto->set_s8s(static_cast(data().data()), + element_count()); + break; case U8: proto->set_u8s(static_cast(data().data()), element_count()); @@ -1872,6 +1854,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { case PRED: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); break; + case S8: { + auto s8_data = data(); + TF_RET_CHECK(proto.s8s().size() == s8_data.size()); + std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin()); + } break; case U8: { auto u8_data = data(); TF_RET_CHECK(proto.u8s().size() == u8_data.size()); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index b928cb637494dec220a0912fdea96ed25cde13ef..1e0a2ad0ddf81d6813942c77ae273e2ce24e735e 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -217,31 +217,20 @@ class LiteralBase { // Converts this literal to the given shape. Returns an error is the // conversion is not possible. - // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. - // - // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes - // the default behavior. 
- StatusOr> ConvertToShape( - const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + StatusOr ConvertToShape(const Shape& dest_shape) const; // Converts this literal to another primitive type using a bitcast // conversion. The to and from primitive types must have the same bit // width. Returns an error if the conversion is not possible. This literal // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; + StatusOr BitcastConvert(PrimitiveType primitive_dest_type) const; // Converts this literal to another primitive type. Returns an error if the // conversion is not possible. This literal must be array-shaped. - StatusOr> Convert( - PrimitiveType primitive_dest_type) const; + StatusOr Convert(PrimitiveType primitive_dest_type) const; - // Clones the underlying buffers into a new Literal, or new - // std::unique_ptr. + // Clones the underlying buffers into a new Literal. Literal Clone() const; - std::unique_ptr CloneToUnique() const; // TODO(b/67651157): The methods below which perform computation on Literals // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with @@ -259,24 +248,23 @@ class LiteralBase { // Note: this is useful when the client wants to ensure that a value placed in // the XLA allocation tracker has a particular layout; for efficiency // purposes or avoiding unimplemented operation/layout combinations. - std::unique_ptr Relayout(const Layout& new_layout, - const ShapeIndex& shape_index = {}) const; + Literal Relayout(const Layout& new_layout, + const ShapeIndex& shape_index = {}) const; // An overload of Relayout which changes the layout of the entire shape rather // than being limited to a single array within the shape. - std::unique_ptr Relayout(const Shape& shape_with_layout) const; + Literal Relayout(const Shape& shape_with_layout) const; // Creates a new literal by reshaping this literal to have the given // dimensions. 
The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. // This literal must be an array. - StatusOr> Reshape( - absl::Span dimensions) const; + StatusOr Reshape(absl::Span dimensions) const; // Creates a new literal by broadcasting this literal with `dimensions` to // yield a literal of shape `result_shape`. - StatusOr> Broadcast( - const Shape& result_shape, absl::Span dimensions) const; + StatusOr Broadcast(const Shape& result_shape, + absl::Span dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -285,7 +273,7 @@ class LiteralBase { // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. // This literal must be an array. - std::unique_ptr Transpose(absl::Span permutation) const; + Literal Transpose(absl::Span permutation) const; // Creates a sub-array from this literal by extracting the indices // [start_index, limit_index) of each dimension. The result literal has the @@ -293,15 +281,15 @@ class LiteralBase { // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. // This literal must be an array. - std::unique_ptr Slice(absl::Span start_indices, - absl::Span limit_indices) const; + Literal Slice(absl::Span start_indices, + absl::Span limit_indices) const; // Creates a literal with a prepended dimension with bound "times"; e.g. a // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this // literal replicated four times. // This literal must be an array. template - std::unique_ptr Replicate(int64 times) const; + Literal Replicate(int64 times) const; // Creates a new Literal object with the shape specified as parameter. 
// The content of the literal values is the default value of the primitive @@ -312,7 +300,7 @@ class LiteralBase { // initialization, then reinitialization. Conside if a call to // absl::make_unique(shape), followed by the call to // MutableLiteralBase::Populate can be used instead. - static std::unique_ptr CreateFromShape(const Shape& shape); + static Literal CreateFromShape(const Shape& shape); protected: // A data structure representing a subshape at a particular ShapeIndex within @@ -539,8 +527,8 @@ class LiteralBase { private: template - std::unique_ptr SliceInternal( - const Shape& result_shape, absl::Span start_indices) const; + Literal SliceInternal(const Shape& result_shape, + absl::Span start_indices) const; }; // Abstract base class representing a mutable literal in XLA. @@ -687,8 +675,7 @@ class MutableLiteralBase : public LiteralBase { static Literal MoveIntoTuple(absl::Span elements); // Serialize from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); + static StatusOr CreateFromProto(const LiteralProto& proto); protected: // Returns the piece at the given ShapeIndex. 
@@ -1137,15 +1124,14 @@ void MutableLiteralBase::PopulateWithValue(NativeT value) { } template -std::unique_ptr LiteralBase::Replicate(int64 times) const { +Literal LiteralBase::Replicate(int64 times) const { DimensionVector bounds = {times}; bounds.reserve(shape().dimensions_size() + 1); for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } - auto literal = absl::make_unique( - ShapeUtil::MakeShape(shape().element_type(), bounds)); - int64 elements = ShapeUtil::ElementsIn(literal->shape()); + Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds)); + int64 elements = ShapeUtil::ElementsIn(literal.shape()); if (elements == 0) { return literal; } @@ -1157,7 +1143,7 @@ std::unique_ptr LiteralBase::Replicate(int64 times) const { bool done = false; while (!done) { const auto element = Get(input_indices); - literal->Set(output_indices, element); + literal.Set(output_indices, element); done = true; for (int n = 0; n < output_indices.size(); ++n) { diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 1a64594db86af31dcc196725d4b4f2a3ad9e4746..7ad287c8973367fb04583e6911ff75e76bdf5f1e 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -92,48 +92,48 @@ class LiteralUtilTest : public ::testing::Test { Layout layout_r3_dim0minor_; Layout layout_r4_dim0major_; Layout layout_r4_dim0minor_; - std::unique_ptr literal_r4_2x2x3x3_dim0major_; - std::unique_ptr literal_r4_2x2x3x3_dim0minor_; + Literal literal_r4_2x2x3x3_dim0major_; + Literal literal_r4_2x2x3x3_dim0minor_; }; TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); - EXPECT_EQ("true", true_lit->ToString()); + EXPECT_EQ("true", true_lit.ToString()); auto false_lit = LiteralUtil::CreateR0(false); - EXPECT_EQ("false", false_lit->ToString()); + EXPECT_EQ("false", false_lit.ToString()); auto u32_lit = LiteralUtil::CreateR0(42); - EXPECT_EQ("42", 
u32_lit->ToString()); + EXPECT_EQ("42", u32_lit.ToString()); auto s32_lit = LiteralUtil::CreateR0(-999); - EXPECT_EQ("-999", s32_lit->ToString()); + EXPECT_EQ("-999", s32_lit.ToString()); auto f32_lit = LiteralUtil::CreateR0(3.14f); - EXPECT_EQ("3.14", f32_lit->ToString()); + EXPECT_EQ("3.14", f32_lit.ToString()); auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", f16_lit->ToString()); + EXPECT_EQ("0.5", f16_lit.ToString()); auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); - EXPECT_EQ("(3.14, 2.78)", c64_lit->ToString()); + EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString()); auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - EXPECT_EQ("0.5", bf16_lit->ToString()); + EXPECT_EQ("0.5", bf16_lit.ToString()); // 3.14 will be rounded to 3.14062 in bfloat16 format. auto bf16_lit_truncated = LiteralUtil::CreateR0(static_cast(3.14f)); - ASSERT_EQ("3.14062", bf16_lit_truncated->ToString()); + ASSERT_EQ("3.14062", bf16_lit_truncated.ToString()); auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); - EXPECT_EQ("9", bf16_lit_truncated2->ToString()); + EXPECT_EQ("9", bf16_lit_truncated2.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - EXPECT_EQ("{101}", pred_vec->ToString()); + EXPECT_EQ("{101}", pred_vec.ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -143,7 +143,7 @@ TEST_F(LiteralUtilTest, R2ToString) { { 3, 4 }, { 5, 6 } })"; - EXPECT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal.ToString()); } TEST_F(LiteralUtilTest, R3ToString) { @@ -157,13 +157,13 @@ TEST_F(LiteralUtilTest, R3ToString) { { { 5 }, { 6 } } })"; - EXPECT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal.ToString()); } TEST_F(LiteralUtilTest, TupleToString) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto 
tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); const string expected = R"((f32[], f32[2,2]) ( 1, f32[2,2] { @@ -171,7 +171,7 @@ f32[2,2] { { 3, 4 } } ))"; - EXPECT_EQ(expected, tuple->ToString()); + EXPECT_EQ(expected, tuple.ToString()); } TEST_F(LiteralUtilTest, CreateR3FromArray3d) { @@ -187,8 +187,8 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { // clang-format on auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); - EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); - string result = literal->ToString(); + EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2)); + string result = literal.ToString(); const string expected = R"(f32[2,3,2] { { { 1, 2 }, { 3, 4 }, @@ -220,10 +220,10 @@ TEST_F(LiteralUtilTest, CreateSparse) { }; std::vector expected_values = {8, 9, 7, 10}; - EXPECT_EQ(literal->sparse_indices()->data(), + EXPECT_EQ(literal.sparse_indices()->data(), absl::Span(expected_indices.data(), expected_indices.num_elements())); - EXPECT_EQ(literal->data(), absl::Span(expected_values)); + EXPECT_EQ(literal.data(), absl::Span(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { @@ -234,8 +234,8 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { {2001, 2002}, }, /*projection_p=*/1, /*projection_z=*/2); // clang-format on - EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); - string result = literal->ToString(); + EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2)); + string result = literal.ToString(); const string expected = R"(f32[1,2,3,2] { { /*i0=0*/ { /*i1=0*/ @@ -254,9 +254,9 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { } TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { - EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(), + EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(), ElementsAre(2, 2, 3, 3)); - string result = literal_r4_2x2x3x3_dim0major_->ToString(); + string result = 
literal_r4_2x2x3x3_dim0major_.ToString(); const string expected = R"(f32[2,2,3,3] { { /*i0=0*/ { /*i1=0*/ @@ -294,7 +294,7 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { }); // clang-format on std::vector> seen; - literal->EachCellAsString( + literal.EachCellAsString( [&seen](absl::Span indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -310,14 +310,14 @@ TEST_F(LiteralUtilTest, ScalarEquality) { auto f32_42 = LiteralUtil::CreateR0(42.0); auto f32_42_clone = LiteralUtil::CreateR0(42.0); - EXPECT_EQ(*f32_42, *f32_42); - EXPECT_EQ(*f32_42, *f32_42_clone); + EXPECT_EQ(f32_42, f32_42); + EXPECT_EQ(f32_42, f32_42_clone); auto f32_123 = LiteralUtil::CreateR0(123.0); - EXPECT_NE(*f32_42, *f32_123); + EXPECT_NE(f32_42, f32_123); auto f64_42 = LiteralUtil::CreateR0(42.0); - EXPECT_NE(*f32_42, *f64_42); + EXPECT_NE(f32_42, f64_42); } TEST_F(LiteralUtilTest, NonScalarEquality) { @@ -330,12 +330,12 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { auto scalar = LiteralUtil::CreateR0(1.0); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(*matrix, *matrix); - EXPECT_EQ(*matrix, *matrix_clone); - EXPECT_NE(*matrix, *matrix_different); - EXPECT_NE(*matrix, *vector_literal); - EXPECT_NE(*matrix, *scalar); - EXPECT_NE(*matrix, nil); + EXPECT_EQ(matrix, matrix); + EXPECT_EQ(matrix, matrix_clone); + EXPECT_NE(matrix, matrix_different); + EXPECT_NE(matrix, vector_literal); + EXPECT_NE(matrix, scalar); + EXPECT_NE(matrix, nil); EXPECT_EQ(nil, nil); } @@ -344,57 +344,54 @@ TEST_F(LiteralUtilTest, TokenEquality) { auto token1 = LiteralUtil::CreateToken(); auto scalar = LiteralUtil::CreateR0(1.0); - EXPECT_EQ(*token0, *token1); - EXPECT_NE(*token0, *scalar); + EXPECT_EQ(token0, token1); + EXPECT_NE(token0, scalar); - EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}), - *LiteralUtil::MakeTuple({token0.get()})); - EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), - *LiteralUtil::MakeTuple({token1.get(), scalar.get()})); - 
EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}), - *LiteralUtil::MakeTuple({scalar.get(), token1.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&token0}), + LiteralUtil::MakeTuple({&token0})); + EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}), + LiteralUtil::MakeTuple({&token1, &scalar})); + EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}), + LiteralUtil::MakeTuple({&scalar, &token1})); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. - auto colmajor = absl::make_unique( - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); - colmajor->Set({0, 0}, 1.0); - colmajor->Set({0, 1}, 2.0); - colmajor->Set({1, 0}, 3.0); - colmajor->Set({1, 1}, 4.0); + Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); + colmajor.Set({0, 0}, 1.0); + colmajor.Set({0, 1}, 2.0); + colmajor.Set({1, 0}, 3.0); + colmajor.Set({1, 1}, 4.0); - auto rowmajor = absl::make_unique( - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); - rowmajor->Set({0, 0}, 1.0); - rowmajor->Set({0, 1}, 2.0); - rowmajor->Set({1, 0}, 3.0); - rowmajor->Set({1, 1}, 4.0); + Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); + rowmajor.Set({0, 0}, 1.0); + rowmajor.Set({0, 1}, 2.0); + rowmajor.Set({1, 0}, 3.0); + rowmajor.Set({1, 1}, 4.0); - EXPECT_EQ(*rowmajor, *colmajor); + EXPECT_EQ(rowmajor, colmajor); } TEST_F(LiteralUtilTest, TupleEquality) { // Test equality with tuples. auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. 
auto scalar_clone = LiteralUtil::CreateR0(1.0); - auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()}); - EXPECT_EQ(*tuple1, *tuple2); + auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix}); + EXPECT_EQ(tuple1, tuple2); // Tuple with elements reversed. - auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()}); - EXPECT_NE(*tuple1, *reversed_tuple); + auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar}); + EXPECT_NE(tuple1, reversed_tuple); // Tuple with different value. auto scalar_42 = LiteralUtil::CreateR0(42.0); - auto different_tuple = - LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()}); - EXPECT_NE(*tuple1, *different_tuple); + auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix}); + EXPECT_NE(tuple1, different_tuple); } TEST_F(LiteralUtilTest, C64Equality) { @@ -405,162 +402,161 @@ TEST_F(LiteralUtilTest, C64Equality) { // tuple, the other is a clone of the element in the original tuple. auto vector_clone = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); - EXPECT_EQ(*vector, *vector_clone); + EXPECT_EQ(vector, vector_clone); auto vector_reversed = LiteralUtil::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); - EXPECT_NE(*vector, *vector_reversed); + EXPECT_NE(vector, vector_reversed); } TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = LiteralUtil::CreateR0(0.0); auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()}); + auto tuple = LiteralUtil::MakeTuple({&element1, &element1}); // Tuples should always return false for IsAll. - EXPECT_FALSE(tuple->IsAll(0)); - EXPECT_FALSE(tuple->IsAll(1)); + EXPECT_FALSE(tuple.IsAll(0)); + EXPECT_FALSE(tuple.IsAll(1)); } // Verifies that CreateFromShape works for tuples. 
TEST_F(LiteralUtilTest, CreateFromShapeTuple) { auto scalar = LiteralUtil::CreateR0(0.0); auto matrix = LiteralUtil::CreateR2({{0, 0}, {0, 0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); - auto x = Literal::CreateFromShape(tuple->shape()); - EXPECT_EQ(*tuple, *x); + auto x = Literal::CreateFromShape(tuple.shape()); + EXPECT_EQ(tuple, x); } TEST_F(LiteralUtilTest, IsAll) { - EXPECT_TRUE(LiteralUtil::CreateR0(false)->IsAll(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(true)->IsAll(1)); - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(1)); - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAll(2)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(2)); - EXPECT_FALSE(LiteralUtil::CreateR0(true)->IsAll(-1)); + EXPECT_TRUE(LiteralUtil::CreateR0(false).IsAll(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(true).IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAll(1)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(2)); + EXPECT_FALSE(LiteralUtil::CreateR0(true).IsAll(-1)); // We shouldn't reinterpret int8_min as an unsigned type and then decide that // it is equal to 255. 
auto int8_min = std::numeric_limits::min(); - EXPECT_FALSE(LiteralUtil::CreateR0(255)->IsAll(int8_min)); + EXPECT_FALSE(LiteralUtil::CreateR0(255).IsAll(int8_min)); - EXPECT_TRUE(LiteralUtil::CreateR0(42.0)->IsAll(42)); - EXPECT_FALSE(LiteralUtil::CreateR0(42.0001)->IsAll(42)); + EXPECT_TRUE(LiteralUtil::CreateR0(42.0).IsAll(42)); + EXPECT_FALSE(LiteralUtil::CreateR0(42.0001).IsAll(42)); - EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100})->IsAll(100)); - EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001})->IsAll(100)); + EXPECT_TRUE(LiteralUtil::CreateR1({100, 100, 100}).IsAll(100)); + EXPECT_FALSE(LiteralUtil::CreateR1({100, 100, 100.001}).IsAll(100)); - EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{8, 8}, {8, 8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{8, 8}, {8, 9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{9, 8}, {8, 8}}).IsAll(8)); half h8(8.0f); half h9(9.0f); - EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{h8}, {h8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h8}, {h9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{h9}, {h8}}).IsAll(8)); bfloat16 b8(8.0f); bfloat16 b9(9.0f); - EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}})->IsAll(8)); - EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}})->IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b8}, {b8}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b8}, {b9}}).IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{b9}, {b8}}).IsAll(8)); // 9.001 will be truncated to 9.0 bfloat16 b91(9.001f); bfloat16 b90(9.00f); - EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, 
{b90}})->IsAll(9.0)); + EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); complex64 c8_9 = {8, 9}; - EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); + EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); auto uint64_max = std::numeric_limits::max(); EXPECT_FALSE(LiteralUtil::CreateR2( {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) - ->IsAll(-1)); + .IsAll(-1)); } TEST_F(LiteralUtilTest, IsAllFloat) { // IsAllFloat always returns false when the literal is not floating-point. - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - - EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + + EXPECT_TRUE(LiteralUtil::CreateR0(0).IsAllFloat(0)); + EXPECT_TRUE(LiteralUtil::CreateR0(.5).IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0)); EXPECT_TRUE(LiteralUtil::CreateR2({{.5, .5, .5}, {.5, .5, .5}}) - ->IsAllFloat(.5)); + .IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(LiteralUtil::CreateR0(.5)->IsAllFloat(.5)); - EXPECT_TRUE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.5)); - EXPECT_FALSE(LiteralUtil::CreateR0(-.5)->IsAllFloat(-.49)); + EXPECT_TRUE(LiteralUtil::CreateR0(0).IsAllFloat(0)); 
+ EXPECT_TRUE(LiteralUtil::CreateR0(.5).IsAllFloat(.5)); + EXPECT_TRUE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.5)); + EXPECT_FALSE(LiteralUtil::CreateR0(-.5).IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); + LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0)); } TEST_F(LiteralUtilTest, IsAllComplex) { // IsAllComplex always returns false when the literal is not complex. - EXPECT_FALSE(LiteralUtil::CreateR0(false)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); - EXPECT_FALSE(LiteralUtil::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(false).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); + EXPECT_FALSE(LiteralUtil::CreateR0(0).IsAllComplex(0)); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c7_9}}) - ->IsAllComplex({8.0f, 9.0f})); + .IsAllComplex({8.0f, 9.0f})); } TEST_F(LiteralUtilTest, IsAllFirst) { // IsAllComplex always returns false when the literal is not complex. 
- EXPECT_FALSE(LiteralUtil::CreateR1({false, true})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({false, false})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); - EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5})->IsAllFirst()); - EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({false, true}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({false, false}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR1({5, 5, 5, 5}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR1({1, 1, 2}).IsAllFirst()); complex64 c8_9 = {8, 9}; complex64 c7_9 = {7, 9}; - EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); - EXPECT_FALSE( - LiteralUtil::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); + EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAllFirst()); + EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}).IsAllFirst()); } TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = LiteralUtil::CreateR0(0.0f); auto scalar_one = LiteralUtil::CreateR0(1.0f); - EXPECT_TRUE(scalar_zero->IsZero({})); - EXPECT_FALSE(scalar_one->IsZero({})); + EXPECT_TRUE(scalar_zero.IsZero({})); + EXPECT_FALSE(scalar_one.IsZero({})); auto array = LiteralUtil::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); - EXPECT_FALSE(array->IsZero({0, 1})); - EXPECT_TRUE(array->IsZero({0, 2})); - EXPECT_TRUE(array->IsZero({1, 1})); - 
EXPECT_FALSE(array->IsZero({1, 2})); + EXPECT_FALSE(array.IsZero({0, 1})); + EXPECT_TRUE(array.IsZero({0, 2})); + EXPECT_TRUE(array.IsZero({1, 1})); + EXPECT_FALSE(array.IsZero({1, 2})); auto complex_zero = LiteralUtil::CreateR0(0.0f); auto complex_nonzero = LiteralUtil::CreateR0(0.5f); - EXPECT_TRUE(complex_zero->IsZero({})); - EXPECT_FALSE(complex_nonzero->IsZero({})); + EXPECT_TRUE(complex_zero.IsZero({})); + EXPECT_FALSE(complex_nonzero.IsZero({})); } template @@ -576,19 +572,19 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); - auto data01 = data->Relayout(layout01); - EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01)); - EXPECT_EQ(*data, *data01); + auto data01 = data.Relayout(layout01); + EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01)); + EXPECT_EQ(data, data01); - auto data10 = data->Relayout(layout10); - EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10)); - EXPECT_EQ(*data, *data10); + auto data10 = data.Relayout(layout10); + EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10)); + EXPECT_EQ(data, data10); } TEST_F(LiteralUtilTest, ReshapeR0) { auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); - EXPECT_EQ(*original, *reshape); + auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); + EXPECT_EQ(original, reshape); } TEST_F(LiteralUtilTest, ReshapeR4) { @@ -606,9 +602,9 @@ TEST_F(LiteralUtilTest, ReshapeR4) { {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_EQ(*expected, *reshape); + EXPECT_EQ(expected, reshape); } TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { @@ -626,15 +622,15 @@ TEST_F(LiteralUtilTest, 
ReshapeR4Dim0Minor) { {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_EQ(*expected, *reshape); + EXPECT_EQ(expected, reshape); } TEST_F(LiteralUtilTest, TransposeR0) { auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = original->Transpose(/*permutation=*/{}); - EXPECT_EQ(*original, *reshape); + auto reshape = original.Transpose(/*permutation=*/{}); + EXPECT_EQ(original, reshape); } TEST_F(LiteralUtilTest, TransposeR4) { @@ -646,10 +642,10 @@ TEST_F(LiteralUtilTest, TransposeR4) { {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}); // clang-format on - auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); + auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1}); - reshape->EachCell([&](absl::Span indices, float value) { - EXPECT_EQ(value, original->Get( + reshape.EachCell([&](absl::Span indices, float value) { + EXPECT_EQ(value, original.Get( {indices[2], indices[3], indices[0], indices[1]})); }); } @@ -658,35 +654,35 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { // Tests that using Relayout on an array is equivalent to creating it in the // target layout in the first place. 
auto dim0minor_relaid_to_dim0major = - literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_); - EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major); + literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_); + EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major); auto dim0major_relaid_to_dim0minor = - literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_); - EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor); + literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_); + EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor); } TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); - EXPECT_EQ(mat_dim0minor->element_count(), 6); - EXPECT_THAT(mat_dim0minor->data(), ElementsAre(1, 4, 2, 5, 3, 6)); + EXPECT_EQ(mat_dim0minor.element_count(), 6); + EXPECT_THAT(mat_dim0minor.data(), ElementsAre(1, 4, 2, 5, 3, 6)); // Test expected memory layout when using Relayout to row major. - auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_); - EXPECT_THAT(relaid_mat_to_dim0major->data(), + auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_); + EXPECT_THAT(relaid_mat_to_dim0major.data(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). auto mat_dim0major = LiteralUtil::CreateR2WithLayout( {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); - EXPECT_EQ(mat_dim0major->element_count(), 6); - EXPECT_THAT(mat_dim0major->data(), ElementsAre(1, 2, 3, 4, 5, 6)); + EXPECT_EQ(mat_dim0major.element_count(), 6); + EXPECT_THAT(mat_dim0major.data(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout when using Relayout to column major. 
- auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_); - EXPECT_THAT(relaid_mat_to_dim0minor->data(), + auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_); + EXPECT_THAT(relaid_mat_to_dim0minor.data(), ElementsAre(1, 4, 2, 5, 3, 6)); } @@ -707,77 +703,77 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout( arr3d, layout_r3_dim0minor_); - EXPECT_EQ(lit_dim0minor->element_count(), 12); + EXPECT_EQ(lit_dim0minor.element_count(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; - EXPECT_THAT(lit_dim0minor->data(), + EXPECT_THAT(lit_dim0minor.data(), testing::ElementsAreArray(expected_dim0minor)); // Test expected memory layout when using Relayout to row major. - auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_); + auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_); std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - EXPECT_THAT(relaid_lit_to_dim0major->data(), + EXPECT_THAT(relaid_lit_to_dim0major.data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( arr3d, layout_r3_dim0major_); - EXPECT_EQ(lit_dim0major->element_count(), 12); - EXPECT_THAT(lit_dim0major->data(), + EXPECT_EQ(lit_dim0major.element_count(), 12); + EXPECT_THAT(lit_dim0major.data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout when using Relayout to column major. 
- auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_); - EXPECT_THAT(relaid_lit_to_dim0minor->data(), + auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_); + EXPECT_THAT(relaid_lit_to_dim0minor.data(), testing::ElementsAreArray(expected_dim0minor)); } TEST_F(LiteralUtilTest, SliceR0S32) { auto input = LiteralUtil::CreateR0(1); - auto result = input->Slice({}, {}); - EXPECT_EQ(*input, *result); + auto result = input.Slice({}, {}); + EXPECT_EQ(input, result); } TEST_F(LiteralUtilTest, SliceR1F32) { auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); - auto result = input->Slice({3}, {4}); + auto result = input.Slice({3}, {4}); auto expected = LiteralUtil::CreateR1({4.0}); - EXPECT_EQ(*expected, *result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, SliceR2U32) { auto input_3x4 = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto result = input_3x4->Slice({0, 2}, {2, 4}); + auto result = input_3x4.Slice({0, 2}, {2, 4}); auto expected = LiteralUtil::CreateR2({{3, 4}, {7, 8}}); - EXPECT_EQ(*expected, *result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, SliceR3U32Full) { auto input_2x3x2 = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); - EXPECT_EQ(*input_2x3x2, *result); + auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2}); + EXPECT_EQ(input_2x3x2, result); } TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output(ShapeUtil::MakeShape(S64, {1})); output.PopulateR1({77}); auto expected = LiteralUtil::CreateR1({77}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR1U64) { Literal output(ShapeUtil::MakeShape(U64, {2})); output.PopulateR1({{77, 88}}); auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR1C64) { Literal 
output(ShapeUtil::MakeShape(C64, {1})); output.PopulateR1({{77, 88}}); auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateR2C64) { @@ -785,7 +781,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); auto expected = LiteralUtil::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { @@ -793,7 +789,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { bfloat16 h(0.25f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR0(h); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { @@ -801,7 +797,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { bfloat16 h(0.5f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { @@ -809,28 +805,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { bfloat16 h(2.0f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output(ShapeUtil::MakeShape(F32, {})); output.PopulateWithValue(2.5f); auto expected = LiteralUtil::CreateR0(2.5f); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output(ShapeUtil::MakeShape(S64, {3})); output.PopulateWithValue(-7); auto expected = LiteralUtil::CreateR1({-7, -7, -7}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output(ShapeUtil::MakeShape(U64, {2, 2})); output.PopulateWithValue(42); auto expected = 
LiteralUtil::CreateR2({{42, 42}, {42, 42}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { @@ -838,7 +834,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { output.PopulateWithValue({4, 2}); auto expected = LiteralUtil::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { @@ -846,7 +842,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { half h(0.25f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR0(h); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { @@ -854,7 +850,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { half h(0.5f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { @@ -862,18 +858,18 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { half h(2.0f); output.PopulateWithValue(h); auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto output = input->Replicate(3); + auto output = input.Replicate(3); auto expected = LiteralUtil::CreateR3( {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); - EXPECT_EQ(*output, *expected); + EXPECT_EQ(output, expected); } TEST_F(LiteralUtilTest, CopySliceFrom) { @@ -889,17 +885,17 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; auto init_proc = [&](absl::Span indexes) { - source->Set(indexes, ++seqnr); + source.Set(indexes, ++seqnr); return true; }; - 
ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step, + ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step, init_proc); auto blank = Literal::CreateFromShape(shape); const int64 src_base[] = {3, 1, 5, 7}; const int64 dest_base[] = {6, 4, 12, 2}; const int64 copy_size[] = {7, 8, 11, 9}; - TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size)); + TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size)); std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); @@ -911,12 +907,12 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, blank_indexes.begin(), std::plus()); - auto bval = blank->Get(blank_indexes); - matched = (bval != 0 && bval == source->Get(source_indexes)); + auto bval = blank.Get(blank_indexes); + matched = (bval != 0 && bval == source.Get(source_indexes)); return matched; }; - ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, + ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step, check_proc); EXPECT_TRUE(matched); } @@ -925,14 +921,14 @@ TEST_F(LiteralUtilTest, CopySliceFrom) { TEST_F(LiteralUtilTest, CopyFromScalars) { auto zero = LiteralUtil::CreateR0(0); auto nine = LiteralUtil::CreateR0(9); - TF_EXPECT_OK(zero->CopyFrom(*nine)); - EXPECT_EQ(*zero, *nine); + TF_EXPECT_OK(zero.CopyFrom(nine)); + EXPECT_EQ(zero, nine); auto vect = LiteralUtil::CreateR1({3, 4, 9, 12, 5, 17, 21}); - TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {})); - EXPECT_EQ(zero->Get({}), 17); - TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {})); - EXPECT_EQ(vect->Get({4}), 17); + TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {})); + EXPECT_EQ(zero.Get({}), 17); + TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {})); + EXPECT_EQ(vect.Get({4}), 17); } TEST_F(LiteralUtilTest, 
CopyFromAndToZeroElement) { @@ -945,17 +941,17 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { const auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = LiteralUtil::CreateR1({9}); - TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0})); - EXPECT_EQ(*nine, *const_nine); + TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0})); + EXPECT_EQ(nine, const_nine); } { // Copy 0 element to destination with zero elements. - const auto empty = Literal::CreateFromShape(empty_r1_shape); + auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = LiteralUtil::CreateR1({9}); - TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0})); - EXPECT_EQ(*empty, *const_empty); + TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0})); + EXPECT_EQ(empty, const_empty); } } @@ -969,74 +965,75 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) { TEST_F(LiteralUtilTest, CopyFromArrays) { auto scalar_42 = LiteralUtil::CreateR0(42.0); auto scalar_123 = LiteralUtil::CreateR0(123.0); - EXPECT_NE(*scalar_42, *scalar_123); - TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{}, - /*src_shape_index=*/{})); - EXPECT_EQ(*scalar_42, *scalar_123); - EXPECT_EQ(scalar_42->Get({}), 123.0f); + EXPECT_NE(scalar_42, scalar_123); + TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{}, + /*src_shape_index=*/{})); + EXPECT_EQ(scalar_42, scalar_123); + EXPECT_EQ(scalar_42.Get({}), 123.0f); auto matrix_1234 = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto matrix_5678 = LiteralUtil::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); - EXPECT_NE(*matrix_1234, *matrix_5678); - EXPECT_EQ(matrix_1234->Get({0, 0}), 1.0f); - TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{}, - /*src_shape_index=*/{})); - EXPECT_EQ(*matrix_1234, *matrix_5678); - EXPECT_EQ(matrix_1234->Get({0, 0}), 5.0f); + EXPECT_NE(matrix_1234, matrix_5678); + EXPECT_EQ(matrix_1234.Get({0, 0}), 1.0f); + TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{}, + 
/*src_shape_index=*/{})); + EXPECT_EQ(matrix_1234, matrix_5678); + EXPECT_EQ(matrix_1234.Get({0, 0}), 5.0f); } TEST_F(LiteralUtilTest, CopyFromTuples) { auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {matrix.get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) - .get()}); + Literal inner_elements[] = {LiteralUtil::CreateR0(42), + LiteralUtil::CreateR1({23.0, 44.0})}; + Literal inner_tuple = LiteralUtil::MakeTuple( + {&inner_elements[0], &inner_elements[1], &nil_literal}); + Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple}); // Create a tuple the same shape as the inner tuple of nested_tuple but with // different values.. - auto tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(-5).get(), - LiteralUtil::CreateR1({2.0, 4.0}).get(), &nil_literal}); + Literal int32_minus5 = LiteralUtil::CreateR0(-5); + Literal double_2_4 = LiteralUtil::CreateR1({2.0, 4.0}); + Literal tuple = + LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal}); - EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); - EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); - EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); - EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); + EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0})); + EXPECT_EQ(nested_tuple.Get({}, {1, 0}), 42); + EXPECT_EQ(nested_tuple.Get({0}, {1, 1}), 23.0); + EXPECT_EQ(nested_tuple.Get({1}, {1, 1}), 44.0); // Overwrite the inner tuple element of nested_tuple with the contents of // 'tuple'. - TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, - /*src_shape_index=*/{})); + TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{})); // The matrix element should be unchanged. 
- EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); + EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0})); // The tuple element should have been copied from 'tuple'. - EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); - EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 2.0); - EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 4.0); + EXPECT_EQ(nested_tuple.Get({}, {1, 0}), -5); + EXPECT_EQ(nested_tuple.Get({0}, {1, 1}), 2.0); + EXPECT_EQ(nested_tuple.Get({1}, {1, 1}), 4.0); } TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { - auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0(-2).get(), - LiteralUtil::CreateR0(4).get()}); + Literal elements[] = {LiteralUtil::CreateR0(-2), + LiteralUtil::CreateR0(4)}; + Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); - EXPECT_EQ(tuple->Get({}, {0}), -2); - EXPECT_EQ(tuple->Get({}, {1}), 4); + EXPECT_EQ(tuple.Get({}, {0}), -2); + EXPECT_EQ(tuple.Get({}, {1}), 4); // Copy from one element to the other. - TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, - /*src_shape_index=*/{0})); + TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{0})); - EXPECT_EQ(tuple->Get({}, {0}), -2); - EXPECT_EQ(tuple->Get({}, {1}), -2); + EXPECT_EQ(tuple.Get({}, {0}), -2); + EXPECT_EQ(tuple.Get({}, {1}), -2); } TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto vector = LiteralUtil::CreateR1({5.0, 7.0}); - Status status = matrix->CopyFrom(*vector); + Status status = matrix.CopyFrom(vector); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.error_message(), HasSubstr("Destination subshape incompatible")); @@ -1046,9 +1043,8 @@ TEST_F(LiteralUtilTest, F16) { // Verify that the internal data views are consistent and that they // are in little endian format // TODO - modify if we make the data format machine endianess dependent - auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); - Literal* l1 = m1.get(); - const char* d1 
= reinterpret_cast(l1->data().data()); + Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + const char* d1 = reinterpret_cast(m1.data().data()); EXPECT_EQ(d1[0], 0); EXPECT_EQ(d1[1], 0); EXPECT_EQ(d1[2], 0); @@ -1061,8 +1057,7 @@ TEST_F(LiteralUtilTest, F16) { half h1(1.0f); half h2(2.0f); auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); - Literal* l2 = m2.get(); - const char* d2 = reinterpret_cast(l2->data().data()); + const char* d2 = reinterpret_cast(m2.data().data()); EXPECT_EQ(d2[0], 0); EXPECT_EQ(d2[1], 0x3C); EXPECT_EQ(d2[2], 0); @@ -1091,25 +1086,25 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = absl::make_unique(shape); + Literal literal(shape); auto generator = [&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), indexes) + 17; }; - TF_EXPECT_OK(literal->Populate(generator)); + TF_EXPECT_OK(literal.Populate(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; auto check_function = [&](absl::Span indexes) { - auto value = literal->Get(indexes); + auto value = literal.Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; - ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step, check_function); EXPECT_TRUE(matched); } @@ -1133,25 +1128,25 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = absl::make_unique(shape); + Literal literal(shape); auto generator = 
[&](absl::Span indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(), indexes) + 17; }; - TF_EXPECT_OK(literal->PopulateParallel(generator)); + TF_EXPECT_OK(literal.PopulateParallel(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; auto check_function = [&](absl::Span indexes) { - auto value = literal->Get(indexes); + auto value = literal.Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; - ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step, + ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step, check_function); EXPECT_TRUE(matched); } @@ -1170,10 +1165,9 @@ TEST_F(LiteralUtilTest, ConvertR4) { {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // clang-format on - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, - original->Convert(U32)); + TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32)); - EXPECT_EQ(*expected, *converted); + EXPECT_EQ(expected, converted); } TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { @@ -1245,69 +1239,65 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); // clang-format on - std::unique_ptr conv; + Literal conv; - conv = s8->Convert(U32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *u32); + conv = s8.Convert(U32).ConsumeValueOrDie(); + EXPECT_EQ(conv, u32); - conv = s8->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = s8.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = s8->Convert(U64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *u64); + conv = s8.Convert(U64).ConsumeValueOrDie(); + EXPECT_EQ(conv, u64); - conv = 
s8->Convert(S64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s64); + conv = s8.Convert(S64).ConsumeValueOrDie(); + EXPECT_EQ(conv, s64); - conv = s8->Convert(PRED).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *pred); + conv = s8.Convert(PRED).ConsumeValueOrDie(); + EXPECT_EQ(conv, pred); - conv = bf16->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = bf16.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = bf16->Convert(F32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f32); + conv = bf16.Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(conv, f32); - conv = pred->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *int32_pred); + conv = pred.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, int32_pred); - conv = f32->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = f32.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = f64->Convert(S32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *s32); + conv = f64.Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(conv, s32); - conv = s32->Convert(F32).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f32); + conv = s32.Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(conv, f32); - conv = f32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = f32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = f64->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = f64.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = s32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = s32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = u32->Convert(F16).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *f16); + conv = u32.Convert(F16).ConsumeValueOrDie(); + EXPECT_EQ(conv, f16); - conv = s32->Convert(C64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *c64); + conv = s32.Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(conv, c64); - conv = f16->Convert(C64).ConsumeValueOrDie(); - EXPECT_EQ(*conv, *c64); + conv = 
f16.Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(conv, c64); - EXPECT_EQ(s32->Convert(TUPLE).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32->Convert(S16).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(s32->Convert(U16).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(c64->Convert(F32).status().code(), - tensorflow::error::UNIMPLEMENTED); - EXPECT_EQ(c64->Convert(S32).status().code(), + EXPECT_EQ(s32.Convert(TUPLE).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED); } TEST_F(LiteralUtilTest, BitcastConvert) { @@ -1317,13 +1307,12 @@ TEST_F(LiteralUtilTest, BitcastConvert) { tensorflow::bit_cast(100.f), 0xbeef}); auto expected = LiteralUtil::CreateR1( {2.5f, -42.25f, 100.0f, tensorflow::bit_cast(0xbeef)}); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr converted, - original->BitcastConvert(F32)); + TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32)); } TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { auto literal = LiteralUtil::CreateR0(1234); - Status status = literal->BitcastConvert(F64).status(); + Status status = literal.BitcastConvert(F64).status(); EXPECT_NE(Status::OK(), status); EXPECT_TRUE( absl::StrContains(status.error_message(), "bit widths are different")); @@ -1341,11 +1330,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) { p.add_preds((i % 2) == (len % 2)); } - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, - Literal::CreateFromProto(p)); - ASSERT_EQ(len, literal->data().size()); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + ASSERT_EQ(len, literal.data().size()); int i = 0; - for (bool value : literal->data()) { + 
for (bool value : literal.data()) { EXPECT_EQ((i % 2) == (len % 2), value); ++i; } @@ -1358,11 +1346,10 @@ TEST_F(LiteralUtilTest, ToProto_f16) { half h2(2.0f); auto m = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); - Literal* l = m.get(); - EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape())); - EXPECT_EQ(4, l->data().size()); + EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape())); + EXPECT_EQ(4, m.data().size()); - LiteralProto p = l->ToProto(); + LiteralProto p = m.ToProto(); EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape())); EXPECT_EQ(8, p.f16s().size()); const char* d = p.f16s().data(); @@ -1389,9 +1376,8 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { LayoutUtil::SetToDefaultLayout(p.mutable_shape()); p.clear_f16s(); p.set_f16s(half_vals, 8); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, - Literal::CreateFromProto(p)); - auto r = literal->data(); + TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p)); + auto r = literal.data(); ASSERT_EQ(4, r.size()); EXPECT_EQ(h1, r[0]); EXPECT_EQ(h2, r[1]); @@ -1402,43 +1388,41 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); - EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix); - EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple); - EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralSlice(scalar, {}), scalar); + EXPECT_EQ(LiteralSlice(matrix, {}), matrix); + EXPECT_EQ(LiteralSlice(tuple, {}), tuple); + EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple); EXPECT_EQ(LiteralSlice(nil, {}), nil); - 
EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar); - EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix); + EXPECT_EQ(LiteralSlice(tuple, {0}), scalar); + EXPECT_EQ(LiteralSlice(tuple, {1}), matrix); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar); - EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix); - EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar); + EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple); + EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar); + EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix); + EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar); } TEST_F(LiteralUtilTest, MutatingLiteralSlice) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. 
- const auto nested_tuple_view = LiteralSlice(*nested_tuple); - EXPECT_EQ( - nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), - 1.0f); + const auto nested_tuple_view = LiteralSlice(nested_tuple); + EXPECT_EQ(nested_tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 1.0f); EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1.0f); - nested_tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); - EXPECT_EQ( - nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), - 555.0f); + nested_tuple.Set(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); + EXPECT_EQ(nested_tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 555.0f); EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 555.0f); @@ -1447,14 +1431,14 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) { TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); - auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); + auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); + auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar}); - const auto nested_tuple_view = LiteralSlice(*nested_tuple); + const auto nested_tuple_view = LiteralSlice(nested_tuple); const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); EXPECT_EQ(matrix_view, - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { @@ -1497,9 +1481,8 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { } TEST_F(LiteralUtilTest, LiteralMove) { - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - Literal literal(std::move(*matrix)); + Literal 
matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal(std::move(matrix)); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); @@ -1511,17 +1494,21 @@ TEST_F(LiteralUtilTest, LiteralMove) { TEST_F(LiteralUtilTest, DecomposeTuple) { Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get(), &nil_literal}) - .get(), - &nil_literal}); - - EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape())); - std::vector elements = nested_tuple->DecomposeTuple(); - EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape())); + Literal inner_elements[] = { + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR1({23.0, 44.0}), + }; + Literal tuple_elements[] = { + LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + LiteralUtil::MakeTuple( + {&inner_elements[0], &inner_elements[1], &nil_literal}), + }; + Literal nested_tuple = LiteralUtil::MakeTuple( + {&tuple_elements[0], &tuple_elements[1], &nil_literal}); + + EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape())); + std::vector elements = nested_tuple.DecomposeTuple(); + EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape())); ASSERT_EQ(elements.size(), 3); @@ -1552,13 +1539,13 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) { TEST_F(LiteralUtilTest, MoveIntoTuple) { std::vector elements; - elements.push_back(std::move(*LiteralUtil::CreateR0(1.0))); - elements.push_back(std::move(*LiteralUtil::CreateR1({4, 8}))); - elements.push_back(std::move(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR1({23.0, 44.0}).get()}) - - )); + elements.push_back(LiteralUtil::CreateR0(1.0)); + elements.push_back(LiteralUtil::CreateR1({4, 8})); + std::vector inner_elements; + inner_elements.push_back(LiteralUtil::CreateR0(42)); + inner_elements.push_back(LiteralUtil::CreateR1({23.0, 44.0})); + 
elements.push_back( + LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]})); Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements)); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); @@ -1586,9 +1573,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { Literal literal; EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape())); - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - literal = std::move(*matrix); + Literal matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + literal = std::move(matrix); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); @@ -1599,9 +1585,8 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { } TEST_F(LiteralUtilTest, LiteralSliceCopy) { - std::unique_ptr matrix = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - const auto matrix_view = LiteralSlice(*matrix); + Literal matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + const auto matrix_view = LiteralSlice(matrix); LiteralSlice matrix_view_copy(matrix_view); EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); @@ -1611,45 +1596,43 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) { } TEST_F(LiteralUtilTest, GetSetTuple) { - auto tuple = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(42.0).get(), - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); - EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); - tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); - EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); - - EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), - 3.0); - tuple->Set(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); - EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), + Literal elements[] = { + LiteralUtil::CreateR0(42.0), + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + }; + auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); + EXPECT_EQ(tuple.Get(/*multi_index=*/{}, 
/*shape_index=*/{0}), 42.0); + tuple.Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); + EXPECT_EQ(tuple.Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); + + EXPECT_EQ(tuple.Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0); + tuple.Set(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); + EXPECT_EQ(tuple.Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), -4.0); } TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { // Literals constructed using CreateFromShape should be zero initialized. - std::unique_ptr scalar_f32 = - Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); - EXPECT_EQ(scalar_f32->Get({}), 0.0); - EXPECT_TRUE(scalar_f32->IsAll(0)); - - std::unique_ptr vector_s32 = - Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); - EXPECT_EQ(vector_s32->Get({0}), 0); - EXPECT_EQ(vector_s32->Get({1}), 0); - EXPECT_EQ(vector_s32->Get({2}), 0); - EXPECT_TRUE(vector_s32->IsAll(0)); - - std::unique_ptr tuple = - Literal::CreateFromShape(ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), - ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); - - EXPECT_EQ(tuple->Get({}, {0}), 0.0); - EXPECT_EQ(tuple->Get({0}, {1}), false); - EXPECT_EQ(tuple->Get({1}, {1}), false); - EXPECT_EQ(tuple->Get({0, 0}, {2}), 0); - EXPECT_EQ(tuple->Get({1, 0}, {2}), 0); - EXPECT_EQ(tuple->Get({}, {3}), complex64(0.0f, 0.0f)); + Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); + EXPECT_EQ(scalar_f32.Get({}), 0.0); + EXPECT_TRUE(scalar_f32.IsAll(0)); + + Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); + EXPECT_EQ(vector_s32.Get({0}), 0); + EXPECT_EQ(vector_s32.Get({1}), 0); + EXPECT_EQ(vector_s32.Get({2}), 0); + EXPECT_TRUE(vector_s32.IsAll(0)); + + Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), + ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); + 
+ EXPECT_EQ(tuple.Get({}, {0}), 0.0); + EXPECT_EQ(tuple.Get({0}, {1}), false); + EXPECT_EQ(tuple.Get({1}, {1}), false); + EXPECT_EQ(tuple.Get({0, 0}, {2}), 0); + EXPECT_EQ(tuple.Get({1, 0}, {2}), 0); + EXPECT_EQ(tuple.Get({}, {3}), complex64(0.0f, 0.0f)); } TEST_F(LiteralUtilTest, ProtoRoundTrip) { @@ -1657,6 +1640,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto one_f32 = LiteralUtil::CreateR0(1.0); auto two_f32 = LiteralUtil::CreateR0(2.0); auto vector_int8 = LiteralUtil::CreateR1({-128, 0, 2, 4, 7, 56, 127}); + auto vector_uint8 = LiteralUtil::CreateR1({128, 0, 2, 56, 127, 255}); auto vector_c64 = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); auto vector_bfloat16 = LiteralUtil::CreateR1( {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); @@ -1665,25 +1649,27 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto tuple = LiteralUtil::MakeTuple( - {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()}); + {&one_f32, &vector_half, &matrix_pred, &matrix_pred}); Literal nil_literal(ShapeUtil::MakeNil()); - auto nested_tuple = LiteralUtil::MakeTuple( - {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal}); + auto nested_tuple = + LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal}); auto to_from_proto = [](const Literal& literal) -> Literal { - return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie()); + return Literal::CreateFromProto(literal.ToProto()).ValueOrDie(); }; - EXPECT_EQ(*one_f32, to_from_proto(*one_f32)); - EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64)); - EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16)); - EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred)); - EXPECT_EQ(*tuple, to_from_proto(*tuple)); - EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple)); + EXPECT_EQ(one_f32, to_from_proto(one_f32)); + EXPECT_EQ(vector_int8, to_from_proto(vector_int8)); + EXPECT_EQ(vector_uint8, 
to_from_proto(vector_uint8)); + EXPECT_EQ(vector_c64, to_from_proto(vector_c64)); + EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); + EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); + EXPECT_EQ(tuple, to_from_proto(tuple)); + EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple)); EXPECT_EQ(nil_literal, to_from_proto(nil_literal)); - EXPECT_NE(*one_f32, *two_f32); - EXPECT_NE(*one_f32, to_from_proto(*two_f32)); + EXPECT_NE(one_f32, two_f32); + EXPECT_NE(one_f32, to_from_proto(two_f32)); } TEST_F(LiteralUtilTest, InvalidProtoNoValues) { @@ -1802,11 +1788,11 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { TEST_F(LiteralUtilTest, SortSparseElements) { auto literal = LiteralUtil::CreateSparse({10, 10, 10}, SparseIndexArray(10, 3), {}); - literal->AppendSparseElement({2, 3, 4}, 2.0); - literal->AppendSparseElement({3, 4, 5}, 3.0); - literal->AppendSparseElement({1, 2, 3}, 1.0); - literal->SortSparseElements(); - EXPECT_EQ(literal->ToString(false), + literal.AppendSparseElement({2, 3, 4}, 2.0); + literal.AppendSparseElement({3, 4, 5}, 3.0); + literal.AppendSparseElement({1, 2, 3}, 1.0); + literal.SortSparseElements(); + EXPECT_EQ(literal.ToString(false), "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); } @@ -1816,57 +1802,54 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {true, false, true}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), "false"); EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat(int64{2})); EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat(double{2.0})); EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {half{1.0}, half{2.0}, half{3.0}}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), 
absl::StrCat(static_cast(half{2.0}))); EXPECT_EQ(LiteralUtil::CreateSparse( dimensions, indices, std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) - ->GetSparseElementAsString(1), + .GetSparseElementAsString(1), absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { - std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); + Literal literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), - /*dimensions=*/{0})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{1, 1}, {2, 2}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{0})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{1, 1}, {2, 2}})); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { - std::unique_ptr literal = LiteralUtil::CreateR1({1, 2}); + Literal literal = LiteralUtil::CreateR1({1, 2}); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), - /*dimensions=*/{1})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{1, 2}, {1, 2}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{1})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{1, 2}, {1, 2}})); } TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { - std::unique_ptr literal = LiteralUtil::CreateR0(9); + Literal literal = LiteralUtil::CreateR0(9); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr broadcasted_literal, - literal->Broadcast( - /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), - /*dimensions=*/{})); - EXPECT_EQ(*broadcasted_literal, - *LiteralUtil::CreateR2({{9, 9}, {9, 9}})); + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), + 
/*dimensions=*/{})); + EXPECT_EQ(broadcasted_literal, + LiteralUtil::CreateR2({{9, 9}, {9, 9}})); } } // namespace diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 613449cf10c785de55e8474c0ee35f78e8ed92b4..0cb1ae35f4ad31f091063d78ed32c1463be8ee0a 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -45,7 +45,7 @@ using absl::StrCat; // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. template -std::unique_ptr ConvertType(LiteralSlice literal) { +Literal ConvertType(LiteralSlice literal) { // First construct shape of the result. Shape result_shape(literal.shape()); ShapeUtil::ForEachMutableSubshape( @@ -56,7 +56,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { primitive_util::NativeToPrimitiveType()); } }); - auto result = absl::make_unique(result_shape); + Literal result(result_shape); // Then copy over the data from 'literal' converting FromNativeT values to // ToNativeT values as necessary. 
@@ -67,14 +67,14 @@ std::unique_ptr ConvertType(LiteralSlice literal) { if (subshape.element_type() == primitive_util::NativeToPrimitiveType()) { auto src = literal.data(shape_index); - auto dest = result->data(shape_index); + auto dest = result.data(shape_index); for (int64 i = 0; i < src.size(); ++i) { dest[i] = static_cast(src[i]); } } else { - TF_CHECK_OK(result->CopyFrom(literal, - /*dest_shape_index=*/shape_index, - /*src_shape_index=*/shape_index)); + TF_CHECK_OK(result.CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); } } }); @@ -83,53 +83,52 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } // namespace -/* static */ std::unique_ptr LiteralUtil::CreateFromDimensions( +/* static */ Literal LiteralUtil::CreateFromDimensions( PrimitiveType primitive_type, absl::Span dimensions) { return Literal::CreateFromShape( ShapeUtil::MakeShape(primitive_type, dimensions)); } -/* static */ std::unique_ptr LiteralUtil::ConvertBF16ToF32( +/* static */ Literal LiteralUtil::ConvertBF16ToF32( const LiteralSlice& bf16_literal) { return ConvertType(bf16_literal); } -/* static */ std::unique_ptr LiteralUtil::ConvertF32ToBF16( +/* static */ Literal LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); } -/* static */ std::unique_ptr LiteralUtil::CreateToken() { - return absl::make_unique(ShapeUtil::MakeTokenShape()); +/* static */ Literal LiteralUtil::CreateToken() { + return Literal(ShapeUtil::MakeTokenShape()); } /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case U32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case U64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S8: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S32: - return 
std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case S64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case F16: - return std::move(*LiteralUtil::CreateR0(static_cast(0.0f))); + return LiteralUtil::CreateR0(static_cast(0.0f)); case BF16: - return std::move( - *LiteralUtil::CreateR0(static_cast(0.0f))); + return LiteralUtil::CreateR0(static_cast(0.0f)); case F32: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case F64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case C64: - return std::move(*LiteralUtil::CreateR0(0)); + return LiteralUtil::CreateR0(0); case PRED: - return std::move(*LiteralUtil::CreateR0(false)); + return LiteralUtil::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -145,30 +144,29 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case U32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case U64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S8: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case S64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case F16: - return std::move(*LiteralUtil::CreateR0(static_cast(1.0f))); + return LiteralUtil::CreateR0(static_cast(1.0f)); case BF16: - return std::move( - *LiteralUtil::CreateR0(static_cast(1.0f))); + return LiteralUtil::CreateR0(static_cast(1.0f)); case F32: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case F64: - return std::move(*LiteralUtil::CreateR0(1)); + return 
LiteralUtil::CreateR0(1); case C64: - return std::move(*LiteralUtil::CreateR0(1)); + return LiteralUtil::CreateR0(1); case PRED: - return std::move(*LiteralUtil::CreateR0(true)); + return LiteralUtil::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -184,42 +182,36 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case U32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case U64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::min())); + return LiteralUtil::CreateR0(std::numeric_limits::min()); case F32: - return std::move(*LiteralUtil::CreateR0( - -std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + -std::numeric_limits::infinity()); case F64: - return std::move(*LiteralUtil::CreateR0( - -std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + -std::numeric_limits::infinity()); case C64: LOG(FATAL) << "C64 element type has no minimum value"; case PRED: - return std::move(*LiteralUtil::CreateR0(false)); + return LiteralUtil::CreateR0(false); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*LiteralUtil::CreateR0( - static_cast(-std::numeric_limits::infinity()))); + 
return LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case BF16: - return std::move(*LiteralUtil::CreateR0( - static_cast(-std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -232,40 +224,34 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case U32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case U64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S8: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S32: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S64: - return std::move( - *LiteralUtil::CreateR0(std::numeric_limits::max())); + return LiteralUtil::CreateR0(std::numeric_limits::max()); case F32: - return std::move(*LiteralUtil::CreateR0( - std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + std::numeric_limits::infinity()); case F64: - return std::move(*LiteralUtil::CreateR0( - std::numeric_limits::infinity())); + return LiteralUtil::CreateR0( + std::numeric_limits::infinity()); case PRED: - return std::move(*LiteralUtil::CreateR0(true)); + return LiteralUtil::CreateR0(true); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return std::move(*LiteralUtil::CreateR0( - static_cast(std::numeric_limits::infinity()))); + return 
LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case BF16: - return std::move(*LiteralUtil::CreateR0( - static_cast(std::numeric_limits::infinity()))); + return LiteralUtil::CreateR0( + static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -275,31 +261,29 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } } -/* static */ std::unique_ptr LiteralUtil::CreateR1( +/* static */ Literal LiteralUtil::CreateR1( const tensorflow::core::Bitmap& values) { - auto literal = absl::make_unique( + Literal literal( ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); - literal->PopulateR1(values); + literal.PopulateR1(values); return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR1U8( - absl::string_view value) { - auto literal = absl::make_unique( - ShapeUtil::MakeShape(U8, {static_cast(value.size())})); +/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) { + Literal literal(ShapeUtil::MakeShape(U8, {static_cast(value.size())})); for (int i = 0; i < value.size(); ++i) { - literal->Set({i}, value[i]); + literal.Set({i}, value[i]); } return literal; } -/* static */ std::unique_ptr LiteralUtil::CreateR2F32Linspace( - float from, float to, int64 rows, int64 cols) { +/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to, + int64 rows, int64 cols) { auto value = MakeLinspaceArray2D(from, to, rows, cols); return CreateR2FromArray2D(*value); } -/* static */ std::unique_ptr LiteralUtil::ReshapeSlice( +/* static */ Literal LiteralUtil::ReshapeSlice( absl::Span new_dimensions, absl::Span minor_to_major, const LiteralSlice& literal) { int64 new_num_elements = 1; @@ -309,13 +293,13 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - auto new_literal = absl::make_unique( + Literal new_literal( 
ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); // Create a new shape with the given minor-to-major layout. This shape is used // solely for converting linear address to multi-dimensional addresses when // writing elements to the new literal. - Shape shape_with_layout = new_literal->shape(); + Shape shape_with_layout = new_literal.shape(); *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); // Copy data into new literal, element-by-element. @@ -326,40 +310,40 @@ std::unique_ptr ConvertType(LiteralSlice literal) { IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); switch (literal.shape().element_type()) { case PRED: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U8: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case S32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case U64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case S64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case F32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case F64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; case C64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); + 
new_literal.Set(to_multi_index, + literal.Get(from_multi_index)); break; default: LOG(FATAL) << "Unhandled primitive element type: " @@ -376,97 +360,82 @@ std::unique_ptr ConvertType(LiteralSlice literal) { CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0); switch (literal.shape().element_type()) { case PRED: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 8 bit types. case S8: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U8: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 16 bit types. case BF16: - return std::move(*LiteralUtil::CreateR0( - literal.GetFirstElement())); + return LiteralUtil::CreateR0( + literal.GetFirstElement()); case F16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U16: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 32 bit types. case F32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U32: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); // 64 bit types. 
case C64: - return std::move(*LiteralUtil::CreateR0( - literal.GetFirstElement())); + return LiteralUtil::CreateR0( + literal.GetFirstElement()); case F64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case S64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); case U64: - return std::move( - *LiteralUtil::CreateR0(literal.GetFirstElement())); + return LiteralUtil::CreateR0(literal.GetFirstElement()); default: LOG(FATAL) << "Unhandled primitive type " << literal.shape().element_type(); } } -/* static */ std::unique_ptr LiteralUtil::MakeTuple( +/* static */ Literal LiteralUtil::MakeTuple( absl::Span elements) { std::vector element_shapes; for (const auto* element : elements) { element_shapes.push_back(element->shape()); } - auto literal = - absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); + TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i})); } return literal; } -/* static */ std::unique_ptr LiteralUtil::MakeTupleFromSlices( +/* static */ Literal LiteralUtil::MakeTupleFromSlices( absl::Span elements) { std::vector element_shapes; for (const auto& element : elements) { element_shapes.push_back(element.shape()); } - auto literal = - absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int i = 0; i < elements.size(); ++i) { - TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); + TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i})); } return literal; } -/* static */ std::unique_ptr LiteralUtil::MakeTupleOwned( - std::vector> elements) { +/* static */ Literal LiteralUtil::MakeTupleOwned( 
+ std::vector elements) { std::vector element_shapes; element_shapes.reserve(elements.size()); for (const auto& element : elements) { - element_shapes.push_back(element->shape()); + element_shapes.push_back(element.shape()); } - auto literal = - absl::make_unique(ShapeUtil::MakeTupleShape(element_shapes)); + Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); for (int64 i = 0; i < elements.size(); ++i) { TF_CHECK_OK( - literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i})); + literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); } return literal; } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 2d6084a67a3b966d054103df0f06ddb82d0d6525..2b181621ed92be8952ccec19e0d4229c494b9f47 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -69,36 +69,34 @@ class LiteralUtil { // The variants not ending with WithLayout use the default XLA layout for the // literal's linear representation in memory. 
template - static std::unique_ptr CreateR0(NativeT value); + static Literal CreateR0(NativeT value); template - static std::unique_ptr CreateR1(absl::Span values); - static std::unique_ptr CreateR1( - const tensorflow::core::Bitmap& values); + static Literal CreateR1(absl::Span values); + static Literal CreateR1(const tensorflow::core::Bitmap& values); template - static std::unique_ptr CreateR2( + static Literal CreateR2( std::initializer_list> values); template - static std::unique_ptr CreateR2WithLayout( + static Literal CreateR2WithLayout( std::initializer_list> values, const Layout& layout); template - static std::unique_ptr CreateR3( - std::initializer_list< - std::initializer_list>> - values); + static Literal CreateR3(std::initializer_list< + std::initializer_list>> + values); template - static std::unique_ptr CreateR3WithLayout( + static Literal CreateR3WithLayout( std::initializer_list< std::initializer_list>> values, const Layout& layout); template - static std::unique_ptr CreateR4( + static Literal CreateR4( std::initializer_list>>> values); template - static std::unique_ptr CreateR4WithLayout( + static Literal CreateR4WithLayout( std::initializer_list>>> values, @@ -139,9 +137,10 @@ class LiteralUtil { // [9, 10, 11]: 4.0 // template - static std::unique_ptr CreateSparse( - absl::Span dimensions, SparseIndexArray indices, - absl::Span values, bool sort = true); + static Literal CreateSparse(absl::Span dimensions, + SparseIndexArray indices, + absl::Span values, + bool sort = true); // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -155,130 +154,120 @@ class LiteralUtil { static Literal MaxValue(PrimitiveType primitive_type); // Creates a literal of the given shape where each element is `value`. 
template - static std::unique_ptr CreateFullWithDescendingLayout( + static Literal CreateFullWithDescendingLayout( absl::Span dimensions, NativeT value); // Creates a new literal from an Array type. The variants not ending with // WithLayout use the default XLA layout for the literal's linear // representation in memory. template - static std::unique_ptr CreateFromArray(const Array& values); + static Literal CreateFromArray(const Array& values); template - static std::unique_ptr CreateFromArrayWithLayout( - const Array& values, const Layout& layout); + static Literal CreateFromArrayWithLayout(const Array& values, + const Layout& layout); template - static std::unique_ptr CreateR2FromArray2D( - const Array2D& values); + static Literal CreateR2FromArray2D(const Array2D& values); template - static std::unique_ptr CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout); + static Literal CreateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout); template - static std::unique_ptr CreateR3FromArray3D( - const Array3D& values); + static Literal CreateR3FromArray3D(const Array3D& values); template - static std::unique_ptr CreateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout); + static Literal CreateR3FromArray3DWithLayout(const Array3D& values, + const Layout& layout); template - static std::unique_ptr CreateR4FromArray4D( - const Array4D& values); + static Literal CreateR4FromArray4D(const Array4D& values); template - static std::unique_ptr CreateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout); + static Literal CreateR4FromArray4DWithLayout(const Array4D& values, + const Layout& layout); // Creates a new vector of U8s literal value from a string. - static std::unique_ptr CreateR1U8(absl::string_view value); + static Literal CreateR1U8(absl::string_view value); // Creates a linspace-populated literal with the given number of rows and // columns. 
- static std::unique_ptr CreateR2F32Linspace(float from, float to, - int64 rows, int64 cols); + static Literal CreateR2F32Linspace(float from, float to, int64 rows, + int64 cols); // Creates a literal that projects the (x, y) dimensions given in values into // the z dimension given by "projection". template - static std::unique_ptr CreateR3Projected( + static Literal CreateR3Projected( std::initializer_list> values, int64 projection); // Creates a literal that projects the (x, y) dimensions given in values into // the z and p dimensions given. template - static std::unique_ptr CreateR4Projected( + static Literal CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z); // Returns an identity matrix (rank 2) with the given row and column count. template - static std::unique_ptr MakeIdentityR2(int64 size); + static Literal MakeIdentityR2(int64 size); // Returns a tuple literal composed of given literals. Data is copied from the // given elements into the returned literal. - static std::unique_ptr MakeTuple( - absl::Span elements); + static Literal MakeTuple(absl::Span elements); - static std::unique_ptr MakeTupleFromSlices( - absl::Span elements); + static Literal MakeTupleFromSlices(absl::Span elements); // As above, but intended to be invoked with move semantics; i.e. // - // std::vector> elements = ...; + // std::vector elements = ...; // auto result = LiteralUtil::MakeTupleOwned(std::move(elements)); // // This would have been declared as an overload, but there is ambiguity // in invocation between the above signature and this one. - static std::unique_ptr MakeTupleOwned( - std::vector> elements); + static Literal MakeTupleOwned(std::vector elements); - // This overload lets you pass a braced list of unique_ptrs to + // This overload lets you pass a braced list of Literals to // MakeTupleOwned: // // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...). 
// - // Simply relying on the MakeTupleOwned(std::vector>) + // Simply relying on the MakeTupleOwned(std::vector) // overload doesn't work because std::initializer_list's elements are always // const. // - // The arguments to this function must all be unique_ptr. + // The arguments to this function must all be Literal. template - static std::unique_ptr MakeTupleOwned( - std::unique_ptr... elements) { - std::array, sizeof...(Ts)> arr{ - std::move(elements)...}; - std::vector> v; + static Literal MakeTupleOwned(Ts... elements) { + std::array arr{std::move(elements)...}; + std::vector v; v.insert(v.begin(), std::make_move_iterator(arr.begin()), std::make_move_iterator(arr.end())); return MakeTupleOwned(std::move(v)); } // Create a constant token literal. Token types have no value. - static std::unique_ptr CreateToken(); + static Literal CreateToken(); // Creates a new Literal object with its values havings the primitive_type // type, and with dimensions defined by the dimensions parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, absl::Span dimensions); + static Literal CreateFromDimensions(PrimitiveType primitive_type, + absl::Span dimensions); // If the given literal's data type is bfloat16, converts it to a float // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. - static std::unique_ptr ConvertBF16ToF32( - const LiteralSlice& bf16_literal); + static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal); // If the given literal's data type is float, converts it to a bfloat16 // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. 
- static std::unique_ptr ConvertF32ToBF16( - const LiteralSlice& f32_literal); + static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); // Creates a literal with a new shape with the given new dimensions using the // data in the given input literal. For reshaping purposes the (flat) data // buffer of the input literal is assumed to have the given minor_to_major // layout order. - static std::unique_ptr ReshapeSlice( - absl::Span new_dimensions, - absl::Span minor_to_major, const LiteralSlice& literal); + static Literal ReshapeSlice(absl::Span new_dimensions, + absl::Span minor_to_major, + const LiteralSlice& literal); // Creates a literal with the supplied shape, and uses the provided value // generator to populate the literal's values. @@ -286,7 +275,7 @@ class LiteralUtil { template < PrimitiveType type, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( + static StatusOr CreateRandomLiteral( const Shape& shape, const std::function)>& generator); @@ -297,8 +286,8 @@ class LiteralUtil { template < PrimitiveType type, typename E, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev); + static StatusOr CreateRandomLiteral(const Shape& shape, E* engine, + T mean, T stddev); // Creates a literal with the supplied shape, and initializes the literal // values using a normal distribution with given mean and stddev standard @@ -307,8 +296,8 @@ class LiteralUtil { template < PrimitiveType type, typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, T mean, T stddev); + static StatusOr CreateRandomLiteral(const Shape& shape, T mean, + T stddev); // // End of factory methods. 
@@ -322,44 +311,43 @@ class LiteralUtil { std::ostream& operator<<(std::ostream& out, const Literal& literal); template -/* static */ std::unique_ptr LiteralUtil::CreateR0(NativeT value) { - auto literal = absl::make_unique(ShapeUtil::MakeShape( +/* static */ Literal LiteralUtil::CreateR0(NativeT value) { + Literal literal(ShapeUtil::MakeShape( primitive_util::NativeToPrimitiveType(), {})); - literal->Set({}, value); + literal.Set({}, value); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR1( - absl::Span values) { - auto literal = absl::make_unique( +/* static */ Literal LiteralUtil::CreateR1(absl::Span values) { + Literal literal( ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), {static_cast(values.size())})); - literal->PopulateR1(values); + literal.PopulateR1(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2WithLayout( +/* static */ Literal LiteralUtil::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { - auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( + Literal literal(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), {static_cast(values.size()), static_cast(values.begin()->size())}, AsInt64Slice(layout.minor_to_major()))); - literal->PopulateR2(values); + literal.PopulateR2(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR2( +/* static */ Literal LiteralUtil::CreateR2( std::initializer_list> values) { return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3WithLayout( +/* static */ Literal LiteralUtil::CreateR3WithLayout( std::initializer_list>> values, const Layout& layout) { @@ -384,14 +372,14 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR3( +/* static */ Literal LiteralUtil::CreateR3( std::initializer_list>> values) { return CreateR3WithLayout(values, 
LayoutUtil::GetDefaultLayoutForR3()); } template -/* static */ std::unique_ptr LiteralUtil::CreateR4WithLayout( +/* static */ Literal LiteralUtil::CreateR4WithLayout( std::initializer_list>>> values, @@ -422,23 +410,22 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateSparse( +/* static */ Literal LiteralUtil::CreateSparse( absl::Span dimensions, SparseIndexArray indices, absl::Span values, bool sort) { int64 num_elements = values.size(); int64 rank = dimensions.size(); CHECK_EQ(num_elements, indices.index_count()); CHECK_EQ(rank, indices.rank()); - auto literal = - absl::make_unique(ShapeUtil::MakeShapeWithSparseLayout( - primitive_util::NativeToPrimitiveType(), dimensions, - indices.max_indices())); - literal->PopulateSparse(indices, values, sort); + Literal literal(ShapeUtil::MakeShapeWithSparseLayout( + primitive_util::NativeToPrimitiveType(), dimensions, + indices.max_indices())); + literal.PopulateSparse(indices, values, sort); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateR4( +/* static */ Literal LiteralUtil::CreateR4( std::initializer_list>>> values) { @@ -446,50 +433,48 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateFromArrayWithLayout( +/* static */ Literal LiteralUtil::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { - auto literal = absl::make_unique(ShapeUtil::MakeShapeWithLayout( + Literal literal(ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), values.dimensions(), AsInt64Slice(layout.minor_to_major()))); - literal->PopulateFromArray(values); + literal.PopulateFromArray(values); return literal; } template -/* static */ std::unique_ptr LiteralUtil::CreateFromArray( +/* static */ Literal LiteralUtil::CreateFromArray( const Array& values) { return CreateFromArrayWithLayout( values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); } template -/* static */ std::unique_ptr 
-LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr LiteralUtil::CreateR2FromArray2D( +/* static */ Literal LiteralUtil::CreateR2FromArray2D( const Array2D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout( + const Array3D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3FromArray3D( +/* static */ Literal LiteralUtil::CreateR3FromArray3D( const Array3D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr LiteralUtil::CreateR3Projected( +/* static */ Literal LiteralUtil::CreateR3Projected( std::initializer_list> values, int64 projection) { int64 dim0_size = projection; @@ -514,7 +499,7 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4Projected( +/* static */ Literal LiteralUtil::CreateR4Projected( std::initializer_list> values, int64 projection_p, int64 projection_z) { int64 dim0_size = projection_p; @@ -542,21 +527,20 @@ template } template -/* static */ std::unique_ptr LiteralUtil::CreateR4FromArray4D( +/* static */ Literal LiteralUtil::CreateR4FromArray4D( const Array4D& values) { return CreateFromArray(values); } template -/* static */ std::unique_ptr -LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout) { +/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout( + const Array4D& values, const Layout& layout) { return CreateFromArrayWithLayout(values, layout); } // Returns an identity matrix (rank 2) with the given row and column count. 
template -/* static */ std::unique_ptr LiteralUtil::MakeIdentityR2(int64 size) { +/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) { Array2D array(size, size, 0); for (int64 i = 0; i < size; ++i) { array(i, i) = 1; @@ -565,33 +549,29 @@ template } template -/* static */ std::unique_ptr -LiteralUtil::CreateFullWithDescendingLayout(absl::Span dimensions, - NativeT value) { - auto literal = - absl::make_unique(ShapeUtil::MakeShapeWithDescendingLayout( - primitive_util::NativeToPrimitiveType(), dimensions)); - literal->PopulateWithValue(value); +/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout( + absl::Span dimensions, NativeT value) { + Literal literal(ShapeUtil::MakeShapeWithDescendingLayout( + primitive_util::NativeToPrimitiveType(), dimensions)); + literal.PopulateWithValue(value); return literal; } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral( +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( const Shape& shape, const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); - auto literal = absl::make_unique(shape); - TF_RETURN_IF_ERROR(literal.get()->Populate( + Literal literal(shape); + TF_RETURN_IF_ERROR(literal.Populate( [&](absl::Span indexes) { return generator(indexes); })); return std::move(literal); } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, - T stddev) { +/* static */ StatusOr LiteralUtil::CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; std::normal_distribution generator(mean, stddev); return CreateRandomLiteral( @@ -600,8 +580,8 @@ LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, } template -/* static */ StatusOr> -LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { +/* static */ StatusOr 
LiteralUtil::CreateRandomLiteral( + const Shape& shape, T mean, T stddev) { std::minstd_rand0 engine; return CreateRandomLiteral(shape, &engine, mean, stddev); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index f9473d372bb15058d7413e2ac8a303dd34322180..0f86f9f35e105713aa3072a9ebf572d33d35d66d 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -39,8 +39,8 @@ PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file) PackedLiteralReader::~PackedLiteralReader() { delete file_; } -StatusOr> PackedLiteralReader::Read( - const Shape& shape, const Layout* layout) { +StatusOr PackedLiteralReader::Read(const Shape& shape, + const Layout* layout) { VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) << " layout: " << (layout == nullptr ? "" : layout->ShortDebugString()); @@ -57,11 +57,11 @@ StatusOr> PackedLiteralReader::Read( PrimitiveType_Name(shape.element_type())); } - auto result = absl::make_unique(literal_shape); - result->PopulateWithValue(std::numeric_limits::quiet_NaN()); + Literal result(literal_shape); + result.PopulateWithValue(std::numeric_limits::quiet_NaN()); int64 elements = ShapeUtil::ElementsIn(shape); - absl::Span field = result->data(); + absl::Span field = result.data(); char* data = absl::bit_cast(field.data()); uint64 bytes = elements * sizeof(float); absl::string_view sp; diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h index 98dccaa9a246520bf60217b96d67a13a24c34b4a..d6d2ff1521bab341b166c4f5c1dc0917e28573d8 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.h +++ b/tensorflow/compiler/xla/packed_literal_reader.h @@ -41,8 +41,7 @@ class PackedLiteralReader { // // Layout is optional. If it is not provided, no layout is set on the literal // that is produced. 
- StatusOr> Read(const Shape& shape, - const Layout* layout = nullptr); + StatusOr Read(const Shape& shape, const Layout* layout = nullptr); // Returns whether the input file has been fully exhausted; i.e. all available // packed literals have been read and we're at the end of the file. diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index cd6e20b69366c064e20c6e0a7d1aebe6229690d8..9da5dc0d2d40cb10640fb0fd2c4c65b4f8e55346 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -81,8 +81,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, return client->TransferToInfeedLocal(literal, device_ordinal); } -StatusOr> TransferFromOutfeedLocalReplica( - const Shape& shape, int replica_number) { +StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, + int replica_number) { VLOG(1) << "Outfeeding literal from replica number: " << replica_number << " shape: " << shape; LocalClient* client = GetOrCreateLocalClient(); @@ -141,9 +141,8 @@ StatusOr LocalShapedBuffer::FromLiteral( LocalClient* client = GetOrCreateLocalClient(); StatusOr buf = [&] { if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, *relaid); + Literal relaid = argument.Relayout(shape_with_layout.value()); + return ToBuffer(client, /*device_ordinal=*/0, relaid); } return ToBuffer(client, /*device_ordinal=*/0, argument); }(); @@ -151,7 +150,7 @@ StatusOr LocalShapedBuffer::FromLiteral( return new LocalShapedBuffer(std::move(buf).ValueOrDie()); } -StatusOr> LocalShapedBuffer::ToLiteral() const { +StatusOr LocalShapedBuffer::ToLiteral() const { LocalClient* client = GetOrCreateLocalClient(); return client->ShapedBufferToLiteral(*shaped_buffer()); } @@ -160,7 +159,7 @@ CompiledLocalComputation::CompiledLocalComputation( 
std::unique_ptr executable) : executable_(std::move(executable)) {} -StatusOr> CompiledLocalComputation::Execute( +StatusOr CompiledLocalComputation::Execute( const std::vector& arguments, const std::vector>& shapes_with_layout) { LocalClient* client = GetOrCreateLocalClient(); @@ -169,7 +168,7 @@ StatusOr> CompiledLocalComputation::Execute( // Each replica populates a StatusOr result, but only replica zero actually // retrieves its literal value. - std::vector>> results(GetReplicaCount()); + std::vector> results(GetReplicaCount()); { tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", GetReplicaCount()); @@ -198,9 +197,8 @@ StatusOr> CompiledLocalComputation::Execute( StatusOr pushed; if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - pushed = ToBuffer(client, device_ordinal, *relaid); + Literal relaid = argument.Relayout(shape_with_layout.value()); + pushed = ToBuffer(client, device_ordinal, relaid); } else { pushed = ToBuffer(client, device_ordinal, argument); } diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 78b3c598b97294d2ba4deb72ec9c1251ef68b7cf..1d5dfe591175735d58a5fe555fffc8043fa4de7e 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -51,8 +51,8 @@ Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); // Transfers a literal of the given shape from the outfeed of the given replica. // // The replica number is resolved to an appropriate device ordinal. -StatusOr > TransferFromOutfeedLocalReplica( - const Shape& shape, int replica_number); +StatusOr TransferFromOutfeedLocalReplica(const Shape& shape, + int replica_number); // Wraps a ScopedShapedBuffer produced by copying a literal "to // device," i.e. 
copying a literal to a scoped buffer via the local @@ -65,7 +65,7 @@ class LocalShapedBuffer { LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); const ScopedShapedBuffer* shaped_buffer() const; - StatusOr > ToLiteral() const; + StatusOr ToLiteral() const; // Transfers ownership of the encapsulated ShapedBuffer to the caller, // analogous to std::unique_ptr::release(). @@ -117,7 +117,7 @@ class CompiledLocalComputation { // with optionally-specified argument layouts. The literals will be // re-laid out according to the corresponding elements of // shapes_with_layout. - StatusOr > Execute( + StatusOr Execute( const std::vector& arguments, const std::vector >& shapes_with_layout); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 76c09512d82006af35e2508ce8e60f23a4c056c3..521490e76c138553c5cc6895412eadb35a939881 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -109,12 +109,12 @@ limitations under the License. 
// Must be included first #include "tensorflow/python/lib/core/numpy.h" -#include "third_party/absl/strings/str_cat.h" -#include "third_party/absl/strings/str_format.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "third_party/absl/types/span.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/python/numpy_bridge.h" #include "tensorflow/compiler/xla/python/local_computation_builder.h" @@ -216,9 +216,9 @@ tensorflow::ImportNumpy(); } -%typemap(out) StatusOr< std::unique_ptr > { +%typemap(out) StatusOr { if ($1.ok()) { - std::unique_ptr value = $1.ConsumeValueOrDie(); + Literal value = $1.ConsumeValueOrDie(); $result = numpy::PyObjectFromXlaLiteral(*value); } else { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); @@ -346,25 +346,25 @@ tensorflow::ImportNumpy(); // Literal -%typemap(in) const Literal& (StatusOr< std::unique_ptr > literal_status) { +%typemap(in) const Literal& (StatusOr literal_status) { literal_status = numpy::XlaLiteralFromPyObject($input); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); SWIG_fail; } - $1 = literal_status.ValueOrDie().get(); + $1 = &literal_status.ValueOrDie(); } -%typemap(out) std::unique_ptr { +%typemap(out) Literal { $result = numpy::PyObjectFromXlaLiteral(*$1); } -%typemap(out) StatusOr< std::unique_ptr > { +%typemap(out) StatusOr { if (!$1.ok()) { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); SWIG_fail; } - $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie()); + $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); } %typemap(in) const std::vector& (std::vector temps) { @@ -375,13 +375,13 @@ tensorflow::ImportNumpy(); const int size = PySequence_Size($input); for (int i = 0; i < size; ++i) { PyObject* o = 
PySequence_GetItem($input, i); - StatusOr< std::unique_ptr > literal_status = numpy::XlaLiteralFromPyObject(o); + StatusOr literal_status = numpy::XlaLiteralFromPyObject(o); if (!literal_status.ok()) { PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); Py_DECREF(o); SWIG_fail; } - temps.push_back(std::move(*literal_status.ConsumeValueOrDie())); + temps.push_back(literal_status.ConsumeValueOrDie()); Py_DECREF(o); } $1 = &temps; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index fc6511bef566cb6f4e0d4e52972954de0792e959..b0aa024c7474cf8e6934432b2f364be464714999 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -368,10 +368,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { } } -StatusOr> XlaLiteralFromPyObject(PyObject* o) { +StatusOr XlaLiteralFromPyObject(PyObject* o) { if (PyTuple_Check(o)) { int num_elements = PyTuple_Size(o); - std::vector> elements; + std::vector elements; elements.reserve(num_elements); for (int i = 0; i < num_elements; i++) { PyObject* element = PyTuple_GetItem(o, i); @@ -389,8 +389,7 @@ StatusOr> XlaLiteralFromPyObject(PyObject* o) { int np_type = PyArray_TYPE(py_array); auto literal = LiteralUtil::CreateFromDimensions( NumpyTypeToPrimitiveType(np_type), dimensions); - TF_RETURN_IF_ERROR( - CopyNumpyArrayToLiteral(np_type, py_array, literal.get())); + TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal)); return std::move(literal); } else { return InvalidArgument( diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 8cae1751853f3cd18033ecf6edca40bf99c6d917..40ff2d9ad214cc4dcad42234fa296834cbc92882 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -82,7 +82,7 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); // 
To avoid transferring ownership of the data buffers that underlie // PyArrays and XLA literals, this function makes deep copies of all // array data. -StatusOr > XlaLiteralFromPyObject(PyObject* o); +StatusOr XlaLiteralFromPyObject(PyObject* o); // The following functions copy array data from the buffers underlying Numpy // ndarrays into those underlying XLA literals, and vice versa. diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index a4854f593f0a579e3461b35033620e762593c6a6..ceb5e74db7c3b9305e9d77068df9ae0a3690af8a 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -186,11 +186,10 @@ ReferenceUtil::SeparableConvArray4D(const Array4D& input, /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow1DGeneric( - const absl::Span& operand, float init, + absl::Span operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{static_cast(operand.size())}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); @@ -218,10 +217,9 @@ ReferenceUtil::ReduceWindow1DGeneric( } /* static */ std::unique_ptr> -ReferenceUtil::ReduceWindow1DAdd(const absl::Span& operand, - float init, - const absl::Span& window, - const absl::Span& stride, +ReferenceUtil::ReduceWindow1DAdd(absl::Span operand, float init, + absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{static_cast(operand.size())}; @@ -234,9 +232,8 @@ ReferenceUtil::ReduceWindow1DAdd(const absl::Span& operand, ReferenceUtil::ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding) { + absl::Span 
window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{operand.height(), operand.width()}; std::vector window_counts(window.size(), 0); @@ -273,9 +270,8 @@ ReferenceUtil::ReduceWindow2DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow2DAdd( - const Array2D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding) { + const Array2D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; std::vector dim_lengths{operand.height(), operand.width()}; return ReduceWindow2DGeneric( @@ -284,9 +280,8 @@ ReferenceUtil::ReduceWindow2DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( - const Array3D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding) { + const Array3D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3()}; auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); @@ -332,8 +327,8 @@ ReferenceUtil::ReduceWindow2DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, Padding padding) { + absl::Span window, absl::Span stride, + Padding padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; return ReduceWindow4DGeneric( @@ -345,9 +340,8 @@ ReferenceUtil::ReduceWindow4DGeneric( ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding) { + absl::Span window, absl::Span stride, + absl::Span> padding) { std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3(), operand.n4()}; @@ -399,9 +393,8 @@ 
ReferenceUtil::ReduceWindow4DGeneric( } /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow4DAdd( - const Array4D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding) { + const Array4D& operand, float init, absl::Span window, + absl::Span stride, Padding padding) { const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, padding); @@ -425,8 +418,8 @@ ReferenceUtil::ReduceWindow4DGeneric( ReferenceUtil::SelectAndScatter4DGePlus(const Array4D& operand, const Array4D& source, float init, - const absl::Span& window, - const absl::Span& stride, + absl::Span window, + absl::Span stride, bool same_padding) { Padding padding = same_padding ? Padding::kSame : Padding::kValid; auto result = absl::make_unique>(operand.n1(), operand.n2(), @@ -529,13 +522,13 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( } ordered_input_dimensions[0] = - lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0)); + lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0)); ordered_input_dimensions[1] = - lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1)); + lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1)); ordered_kernel_dimensions[0] = - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)); ordered_kernel_dimensions[1] = - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)); std::vector> paddings = MakePadding(ordered_input_dimensions, ordered_kernel_dimensions, @@ -546,7 +539,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( WindowDimension dim; dim.set_size( - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0))); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0))); 
dim.set_stride(kernel_stride.first); dim.set_padding_low(paddings[0].first); dim.set_padding_high(paddings[0].second); @@ -556,7 +549,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( WindowDimension dim2; dim2.set_size( - rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1))); + rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1))); dim2.set_stride(kernel_stride.second); dim2.set_padding_low(paddings[1].first); dim2.set_padding_high(paddings[1].second); @@ -564,35 +557,39 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( dim2.set_base_dilation(lhs_dilation.second); *window.add_dimensions() = dim2; - const Shape& shape = - ShapeInference::InferConvolveShape(lhs_literal->shape(), - rhs_literal->shape(), window, dnums) - .ConsumeValueOrDie(); + const Shape& shape = ShapeInference::InferConvolveShape( + lhs_literal.shape(), rhs_literal.shape(), + /*feature_group_count=*/1, window, dnums) + .ConsumeValueOrDie(); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfig::DEFAULT); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, precision_config)); HloModuleConfig config; HloModule module("ReferenceUtil", config); auto computation = module.AddEntryComputation(b.Build()); HloEvaluator evaluator; - std::unique_ptr result_literal = + Literal result_literal = evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); - CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); + CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4); auto result = - absl::make_unique>(result_literal->shape().dimensions(0), - 
result_literal->shape().dimensions(1), - result_literal->shape().dimensions(2), - result_literal->shape().dimensions(3)); + absl::make_unique>(result_literal.shape().dimensions(0), + result_literal.shape().dimensions(1), + result_literal.shape().dimensions(2), + result_literal.shape().dimensions(3)); result->Each([&](absl::Span indices, float* value) { - *value = result_literal->Get(indices); + *value = result_literal.Get(indices); }); return result; diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 9ce098029dbc35f6b4bab2efd77bee2b7e1a6255..8654fbb9b5e16c5ac13cb29aafeef8d142dbe39f 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -177,47 +177,41 @@ class ReferenceUtil { // Windowed reductions with Add as the function to apply. static std::unique_ptr> ReduceWindow1DAdd( - const absl::Span& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding); + absl::Span operand, float init, + absl::Span window, absl::Span stride, + Padding padding); static std::unique_ptr> ReduceWindow2DAdd( - const Array2D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding); + const Array2D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); static std::unique_ptr> ReduceWindow3DAdd( - const Array3D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding); + const Array3D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); static std::unique_ptr> ReduceWindow4DAdd( - const Array4D& operand, float init, - const absl::Span& window, - const absl::Span& stride, Padding padding); + const Array4D& operand, float init, absl::Span window, + absl::Span stride, Padding padding); // Windowed reductions with a generic reduce function. 
static std::unique_ptr> ReduceWindow1DGeneric( - const absl::Span& operand, float init, + absl::Span operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); static std::unique_ptr> ReduceWindow2DGeneric( const Array2D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, Padding padding); + absl::Span window, absl::Span stride, + Padding padding); // With arbitrary padding. static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, - const absl::Span& window, - const absl::Span& stride, - const absl::Span>& padding); + absl::Span window, absl::Span stride, + absl::Span> padding); // Batch normalize data. static std::unique_ptr> BatchNorm4D( @@ -230,8 +224,8 @@ class ReferenceUtil { // TODO(b/74533103) Switch tests to evaluator and remove this implementation. static std::unique_ptr> SelectAndScatter4DGePlus( const Array4D& operand, const Array4D& source, float init, - const absl::Span& window, - const absl::Span& stride, bool same_padding); + absl::Span window, absl::Span stride, + bool same_padding); // Concatenates the lhs and rhs arrays along the concatenate_dimension. // E.g. 
if concatenate_dimension is 0, the "n1"/height dimension is @@ -332,8 +326,8 @@ class ReferenceUtil { // Slices with index clamping template - static std::vector ClampSlice1D(const absl::Span& input, - int64 start, int64 size) { + static std::vector ClampSlice1D(absl::Span input, int64 start, + int64 size) { start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 3ec0192148492c2516bf1c14fd4b960b08014388..a1b0f4045ff071454451f9fe3942ac974f4f47ac 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -55,7 +55,7 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MatmulArray2D) { @@ -67,14 +67,14 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); auto actual_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, + LiteralTestUtil::ExpectR1Near({6.f, 15.f}, actual_literal, ErrorSpec(0.0001)); } @@ -82,7 +82,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 
0.0f, add); auto actual_literal = LiteralUtil::CreateR1(*result); - LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, + LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, actual_literal, ErrorSpec(0.0001)); } @@ -90,14 +90,14 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { auto result = LiteralUtil::CreateR1(ReferenceUtil::Reduce4DTo1D( Array4D(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2}, [](float a, float b) { return a + b; })); - LiteralTestUtil::ExpectR1Equal({0}, *result); + LiteralTestUtil::ExpectR1Equal({0}, result); } TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, + LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal, ErrorSpec(0.0001)); } @@ -108,7 +108,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, MapArray4D) { @@ -121,7 +121,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) { Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -138,7 +138,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -146,16 +146,16 @@ 
TEST_F(ReferenceUtilTest, SliceArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}}); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, - *actual_literal, ErrorSpec(0.0001)); + LiteralTestUtil::ExpectR2Near({{1.f, 2.f}, {4.f, 5.f}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceStridedArray2D) { auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}}); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); - LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, - *actual_literal, ErrorSpec(0.0001)); + LiteralTestUtil::ExpectR2Near({{1.f, 3.f}, {4.f, 6.f}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceArray3D) { @@ -167,7 +167,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( - {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal, + {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal, ErrorSpec(0.0001)); } @@ -180,8 +180,8 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result); LiteralTestUtil::ExpectR3Near( - {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, - *actual_literal, ErrorSpec(0.0001)); + {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal, + ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceArray4D) { @@ -194,7 +194,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) { LiteralTestUtil::ExpectR4Near( {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, SliceStridedArray4D) { @@ -208,7 +208,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) { LiteralTestUtil::ExpectR4Near( {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}}, {{100.f, 102.f, 
104.f}, {110.f, 112.f, 114.f}}}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { @@ -220,7 +220,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); - LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, + LiteralTestUtil::ExpectR3NearArray3D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -233,7 +233,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) { auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual); - LiteralTestUtil::ExpectR3NearArray3D(expected, *actual_literal, + LiteralTestUtil::ExpectR3NearArray3D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -268,7 +268,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -302,7 +302,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -358,7 +358,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -411,7 +411,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); - LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, + LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal, ErrorSpec(0.0001)); } @@ -424,7 +424,7 @@ TEST_F(ReferenceUtilTest, 
ApplyElementwise2D) { [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, - *actual_literal, ErrorSpec(0.0001)); + actual_literal, ErrorSpec(0.0001)); } } // namespace diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 43fd8fe1bd0f41eb2ac5c42021a8ca4f63282646..84fe5b17d10fba8c9f44314bec2b827e98ff6b33 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -95,12 +95,11 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) { std::vector expected = { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); + Literal expected_literal = LiteralUtil::CreateR1(expected); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( computation, {}, nullptr)); - EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal, + EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal, ErrorSpec(0.0001))); } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 26b48cf4196ce24a8a20f407f698d951e18193f9..fb80c78f6852db7d69aeef752b5f692d47d58bed 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -87,6 +87,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -123,6 +124,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -159,6 +161,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep ], @@ -291,6 +294,7 @@ cc_library( "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", + "hlo_schedule.cc", "hlo_sharding.cc", ], hdrs = [ @@ -303,6 +307,7 @@ cc_library( "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", + "hlo_schedule.h", "hlo_sharding.h", ], deps = [ @@ -331,6 +336,8 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -347,6 +354,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -397,6 +405,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -493,6 +502,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -541,6 +551,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -563,6 +574,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -989,6 +1001,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1006,8 +1019,8 @@ cc_library( ":buffer_value_containers", ":heap_simulator", ":hlo", + ":hlo_memory_scheduler", ":hlo_proto", - ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1035,8 +1048,8 @@ tf_cc_test( ":cpu_plugin", ":flatten_call_graph", ":hlo", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1049,6 +1062,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1081,14 +1095,15 @@ tf_cc_test( deps = [ ":hlo", ":hlo_dataflow_analysis", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1123,12 +1138,45 @@ tf_cc_test( "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) +cc_library( + name = "hlo_module_group", + srcs = ["hlo_module_group.cc"], + hdrs = ["hlo_module_group.h"], + deps = [ + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_module_group_test", + srcs = ["hlo_module_group_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":hlo_module_group", + ":hlo_parser", + ":hlo_proto", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_module_group_metadata", srcs = ["hlo_module_group_metadata.cc"], @@ -1169,14 +1217,35 @@ cc_library( ], ) +tf_cc_test( + name = "hlo_schedule_test", + srcs = ["hlo_schedule_test.cc"], + deps = [ + ":heap_simulator", + ":hlo", + ":hlo_dce", + ":hlo_memory_scheduler", + ":hlo_ordering", + ":hlo_parser", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + ], +) + cc_library( - name = "hlo_scheduling", - srcs = ["hlo_scheduling.cc"], - hdrs = 
["hlo_scheduling.h"], + name = "hlo_memory_scheduler", + srcs = ["hlo_memory_scheduler.cc"], + hdrs = ["hlo_memory_scheduler.h"], deps = [ ":heap_simulator", ":hlo", ":hlo_ordering", + ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1190,21 +1259,22 @@ cc_library( ) tf_cc_test( - name = "hlo_scheduling_test", - srcs = ["hlo_scheduling_test.cc"], + name = "hlo_memory_scheduler_test", + srcs = ["hlo_memory_scheduler_test.cc"], deps = [ ":heap_simulator", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", ":hlo_parser", - ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1229,6 +1299,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1362,6 +1433,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -1678,6 +1750,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -1747,6 +1820,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", @@ -1922,6 +1996,9 @@ tf_cc_test( srcs = ["hlo_module_test.cc"], deps = [ ":hlo", + ":hlo_matchers", + ":hlo_memory_scheduler", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1930,6 +2007,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -2203,6 +2281,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2281,6 +2360,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], ) @@ -2361,12 +2441,11 @@ cc_library( ":buffer_liveness", ":buffer_value", ":call_graph", - ":copy_insertion", ":flatten_call_graph", ":hlo", ":hlo_dce", + ":hlo_memory_scheduler", ":hlo_ordering", - ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -2395,6 +2474,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2461,6 +2541,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", 
"//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -2520,6 +2601,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -2538,6 +2620,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -2576,6 +2659,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -2853,6 +2937,7 @@ tf_cc_test( deps = [ ":hlo_tfgraph_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", ], @@ -3187,6 +3272,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3289,6 +3375,8 @@ tf_cc_test( size = "small", srcs = ["hlo_parser_test.cc"], deps = [ + ":hlo", + ":hlo_casting_utils", ":hlo_matchers", ":hlo_parser", "//tensorflow/compiler/xla:window_util", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 7c078f07d72ab4243d50b7f7910cb7c794e306c4..5458159d149c627b1121fd8a30e073b712542390 100644 --- 
a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -205,7 +205,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique())); + LiteralUtil::Zero(hlo->shape().element_type()).Clone())); HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( @@ -296,6 +296,14 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { return scalar_add_computation_; } + // Tries to fold a kPad in the input or filter into the convolution + // instruction's window. + StatusOr FoldConvInputPad(HloInstruction* convolution); + StatusOr FoldConvFilterPad(HloInstruction* convolution); + + // Tries to use a kDot in place of the given convolution. + StatusOr SimplifyConvToDot(HloInstruction* convolution); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -312,7 +320,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot strength reduction on platforms where it causes a slowdown. bool enable_dot_strength_reduction_; - // Disable convolution simplification on platforms where it causes a slowdown. + // Disable convolution -> dot simplification on platforms where it causes a + // slowdown. bool enable_conv_simplification_; // Cached computation for adding two scalar F32. 
@@ -527,7 +536,7 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation, return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { return computation->AddInstruction( - HloInstruction::CreateConstant(literal.CloneToUnique())); + HloInstruction::CreateConstant(literal.Clone())); } } @@ -546,7 +555,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr unique_scalar = absl::make_unique( + Literal unique_scalar( LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); @@ -676,7 +685,7 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } auto inverse = computation_->AddInstruction( - HloInstruction::CreateConstant((new_literal.CloneToUnique()))); + HloInstruction::CreateConstant((new_literal.Clone()))); TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); return ReplaceInstruction(divide, new_divide); @@ -950,9 +959,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( new_dot_rhs = rhs_slice; } - auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums)); - new_dot->set_precision_config(dot.precision_config()); + auto* new_dot = computation_->AddInstruction( + HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs, + new_dot_dnums, dot.precision_config())); if (add_result) { add_result = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -1053,9 +1062,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( const int n = right_operand->shape().dimensions(1 - rhs_contracting_dimension); auto memoized_shape = 
ShapeUtil::MakeShape(F32, {m, n}); - auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( - memoized_shape, left_operand, right_operand, dnums)); - memoized_inst->set_precision_config(dot->precision_config()); + auto* memoized_inst = computation_->AddInstruction( + HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, + dnums, dot->precision_config())); // Get pair {start, 0} or {0, start}. HloInstruction* original_start_indices = lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); @@ -1151,9 +1160,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), - rhs->mutable_operand(0), lhs->mutable_operand(0), - dot_dimension_numbers)); - new_dot->set_precision_config(dot->precision_config()); + rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers, + dot->precision_config())); return ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } @@ -1470,7 +1478,7 @@ Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) { auto* iota = Cast(instruction); if (iota->shape().dimensions(iota->iota_dimension()) <= 1) { auto zero = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique())); + LiteralUtil::Zero(iota->shape().element_type()).Clone())); return ReplaceWithNewInstruction( iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {})); } @@ -1573,7 +1581,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( - LiteralUtil::One(power->shape().element_type()).CloneToUnique()); + 
LiteralUtil::One(power->shape().element_type()).Clone()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { ones = std::move(one); @@ -1608,7 +1616,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::One(rhs->shape().element_type()).CloneToUnique())); + LiteralUtil::One(rhs->shape().element_type()).Clone())); // Explicitly broadcast scalar 1 to the output shape, to avoid implicit // broadcast in divide HLO as we are trying to eliminate implicit @@ -2058,12 +2066,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( if (pad_literal == reduce_init_literal) { return true; } - auto converted_pad_literal = pad_literal.ConvertToShape( - reduce_init_value->shape(), /*round_f32_to_bf16=*/true); + auto converted_pad_literal = + pad_literal.ConvertToShape(reduce_init_value->shape()); if (!converted_pad_literal.ok()) { return false; } - return *converted_pad_literal.ValueOrDie() == reduce_init_literal; + return converted_pad_literal.ValueOrDie() == reduce_init_literal; }; // The pad value is usually a constant, so we handle that case and do not // try to get more fancy about proving equivalence in cases beyond that. 
@@ -2213,170 +2221,155 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleConvolution( +StatusOr AlgebraicSimplifierVisitor::FoldConvInputPad( HloInstruction* convolution) { - auto lhs = convolution->mutable_operand(0); - auto rhs = convolution->mutable_operand(1); - if (ShapeUtil::IsZeroElementArray(lhs->shape()) || - ShapeUtil::IsZeroElementArray(rhs->shape())) { - return ReplaceWithNewInstruction( - convolution, - HloInstruction::CreateBroadcast( - convolution->shape(), - computation_->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(convolution->shape().element_type()) - .CloneToUnique())), - {})); - } - + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); const auto& window = convolution->window(); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - // Try to merge padding/dilation of the input with the convolution's window. - TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr { - if (lhs->opcode() != HloOpcode::kPad) { + if (lhs->opcode() != HloOpcode::kPad) { + return false; + } + + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(lhs->operand(1), 0)) { + return false; + } + + const auto& padding = lhs->padding_config(); + + // Can't pad batch or feature dims. + for (int64 dim : + {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { return false; } + } - // Convolution's padding is always zero, so bail if the kPad is adding - // something other than zero. - if (!IsAll(lhs->operand(1), 0)) { + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. 
+ Window new_window = window; + for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); + // Edge padding composes with itself in the straightforward way, but + // composing interior padding is nontrivial, and we cowardly refuse to + // think about it. If we see interior padding in either the kPad or conv, + // bail if there's any sort of padding in the other. + if (p.interior_padding() != 0 && + (w.padding_low() != 0 || w.padding_high() != 0 || + w.base_dilation() != 1)) { + return false; + } + if (w.base_dilation() != 1 && + (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0)) { return false; } - const auto& padding = lhs->padding_config(); - - // Can't pad batch or feature dims. - for (int64 dim : - {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) { - const auto& p = padding.dimensions(dim); - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0) { - return false; - } + w.set_padding_low(w.padding_low() + p.edge_padding_low()); + w.set_padding_high(w.padding_high() + p.edge_padding_high()); + if (p.interior_padding() != 0) { + CHECK_EQ(w.base_dilation(), 1); + w.set_base_dilation(1 + p.interior_padding()); } + } - // Compute the window which is the result of merging the kPad and the - // convolution's existing window. - Window new_window = window; - for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) { - auto& w = *new_window.mutable_dimensions(dim); - const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim)); - // Edge padding composes with itself in the straightforward way, but - // composing interior padding is nontrivial, and we cowardly refuse to - // think about it. If we see interior padding in either the kPad or conv, - // bail if there's any sort of padding in the other. 
- if (p.interior_padding() != 0 && - (w.padding_low() != 0 || w.padding_high() != 0 || - w.base_dilation() != 1)) { - return false; - } - if (w.base_dilation() != 1 && - (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0)) { - return false; - } + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs->mutable_operand(0), rhs}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} - w.set_padding_low(w.padding_low() + p.edge_padding_low()); - w.set_padding_high(w.padding_high() + p.edge_padding_high()); - if (p.interior_padding() != 0) { - CHECK_EQ(w.base_dilation(), 1); - w.set_base_dilation(1 + p.interior_padding()); - } - } +StatusOr AlgebraicSimplifierVisitor::FoldConvFilterPad( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); - auto new_conv = convolution->CloneWithNewOperands( - convolution->shape(), {lhs->mutable_operand(0), rhs}); - new_conv->set_window(new_window); - TF_RETURN_IF_ERROR( - ReplaceWithNewInstruction(convolution, std::move(new_conv))); - return true; - }()); + if (rhs->opcode() != HloOpcode::kPad) { + return false; + } - if (folded_input_pad) { - return Status::OK(); + // Convolution's padding is always zero, so bail if the kPad is adding + // something other than zero. + if (!IsAll(rhs->operand(1), 0)) { + return false; } - // Try to merge dilation of the filter with the convolution's window. - TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr { - if (rhs->opcode() != HloOpcode::kPad) { - return false; - } + const auto& padding = rhs->padding_config(); - // Convolution's padding is always zero, so bail if the kPad is adding - // something other than zero. 
- if (!IsAll(rhs->operand(1), 0)) { + // Can't pad or dilate feature dims. + for (int64 dim : {dnums.kernel_input_feature_dimension(), + dnums.kernel_output_feature_dimension()}) { + const auto& p = padding.dimensions(dim); + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || + p.interior_padding() != 0) { return false; } + } - const auto& padding = rhs->padding_config(); + // Compute the window which is the result of merging the kPad and the + // convolution's existing window. + Window new_window = convolution->window(); + for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { + auto& w = *new_window.mutable_dimensions(dim); + const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); - // Can't pad or dilate feature dims. - for (int64 dim : {dnums.kernel_input_feature_dimension(), - dnums.kernel_output_feature_dimension()}) { - const auto& p = padding.dimensions(dim); - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 || - p.interior_padding() != 0) { - return false; - } + // We can only do this transformation if p adds dilation to the filter -- + // edge padding on the filter is not supported in conv. + if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { + return false; } - // Compute the window which is the result of merging the kPad and the - // convolution's existing window. - Window new_window = convolution->window(); - for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) { - auto& w = *new_window.mutable_dimensions(dim); - const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim)); - - // We can only do this transformation if p adds dilation to the filter -- - // edge padding on the filter is not supported in conv. - if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) { - return false; - } - - // Nothing to do if the kPad for this dim is entirely a nop. 
- if (p.interior_padding() == 0) { - continue; - } + // Nothing to do if the kPad for this dim is entirely a nop. + if (p.interior_padding() == 0) { + continue; + } - // We cowardly refuse to think about how dilation composes with itself; - // bail if both the kPad and conv have dilation on this dimension. - if (w.window_dilation() > 1) { - return false; - } - CHECK_EQ(w.window_dilation(), 1); - w.set_window_dilation(1 + p.interior_padding()); - w.set_size(rhs->operand(0)->shape().dimensions( - dnums.kernel_spatial_dimensions(dim))); + // We cowardly refuse to think about how dilation composes with itself; + // bail if both the kPad and conv have dilation on this dimension. + if (w.window_dilation() > 1) { + return false; } + CHECK_EQ(w.window_dilation(), 1); + w.set_window_dilation(1 + p.interior_padding()); + w.set_size(rhs->operand(0)->shape().dimensions( + dnums.kernel_spatial_dimensions(dim))); + } - auto new_conv = convolution->CloneWithNewOperands( - convolution->shape(), {lhs, rhs->mutable_operand(0)}); - new_conv->set_window(new_window); - TF_RETURN_IF_ERROR( - ReplaceWithNewInstruction(convolution, std::move(new_conv))); - return true; - }()); + auto new_conv = convolution->CloneWithNewOperands( + convolution->shape(), {lhs, rhs->mutable_operand(0)}); + new_conv->set_window(new_window); + TF_RETURN_IF_ERROR( + ReplaceWithNewInstruction(convolution, std::move(new_conv))); + return true; +} - if (folded_filter_pad) { - return Status::OK(); - } +StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( + HloInstruction* convolution) { + auto* lhs = convolution->mutable_operand(0); + auto* rhs = convolution->mutable_operand(1); + const auto& window = convolution->window(); + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); if (!enable_conv_simplification_) { - return Status::OK(); + return false; } - // HandleConvolution tries to replace a convolution with a DOT instruction. 
- // - // Only add when bitcasts can be used: - // - if bitcasts are not supported, then reshapes could be used but will - // end up with another copy. - // - if bitcasts are supported, the simplifier will be called again with - // bitcasts_ == true. - // TODO(cwhipkey): b/31337498, make this layout insensitive. + // TODO(b/31337498): For now, we cowardly refuse to do this optimization in + // layout-insensitive mode, for fear of adding nontrivial reshapes. if (!is_layout_sensitive_) { - return Status::OK(); + return false; } const Shape& input_shape = lhs->shape(); @@ -2389,7 +2382,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // Require the spatial dimensions in the kernel to have a bound of one. for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) { - return Status::OK(); + return false; } } @@ -2400,7 +2393,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // for a 1x1 window, so window dilation is no problem. 
if (window_util::HasStride(window) || window_util::HasPadding(window) || window_util::HasBaseDilation(window)) { - return Status::OK(); + return false; } // Also, the shapes must align for a rowmajor matmul: @@ -2426,7 +2419,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dnums.kernel_input_feature_dimension()) < PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_output_feature_dimension()))) { - return Status::OK(); + return false; } auto add_bitcast = [&](Shape shape, HloInstruction* operand) { @@ -2468,7 +2461,7 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( if (!valid_bitcast_callback_(input_shape, new_input_shape) || !valid_bitcast_callback_(filter_shape, new_filter_shape) || !valid_bitcast_callback_(dot_output_shape, convolution_shape)) { - return Status::OK(); + return false; } auto new_lhs = add_bitcast(new_input_shape, lhs); @@ -2477,10 +2470,47 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_dimension_numbers.add_lhs_contracting_dimensions(1); dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); - dot->set_precision_config(convolution->precision_config()); + dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, + convolution->precision_config())); + + TF_RETURN_IF_ERROR( + ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot))); + return true; +} + +Status AlgebraicSimplifierVisitor::HandleConvolution( + HloInstruction* convolution) { + // Zero-sized input or filter. 
+ if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) || + ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) { + return ReplaceWithNewInstruction( + convolution, + HloInstruction::CreateBroadcast( + convolution->shape(), + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(convolution->shape().element_type()))), + {})); + } - return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); + // Try to merge padding/dilation of the input with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution)); + if (folded_input_pad) { + return Status::OK(); + } + + // Try to merge dilation of the filter with the convolution's window. + TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution)); + if (folded_filter_pad) { + return Status::OK(); + } + + // Try to replace the convolution with a kDot instruction. + TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution)); + if (replaced_with_dot) { + return Status::OK(); + } + + return Status::OK(); } bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 43a891e4fa163e833692a8e71b8f2f21d377e323..3fc1ba24271b40de0a24ed4c957cd83aca736f55 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1044,7 +1044,8 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { dim->set_window_reversal(false); // Create add computation. 
builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums)); + ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -2260,9 +2261,11 @@ TEST_P(ConvInputPaddingTest, DoTest) { .ValueOrDie(); builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(), - window, dnums) + /*feature_group_count=*/1, window, + dnums) .ValueOrDie(), - lhs_pad, filter, window, dnums)); + lhs_pad, filter, /*feature_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -2366,18 +2369,20 @@ TEST_P(ConvFilterPaddingTest, DoIt) { rhs_pad->shape().dimensions(3), testcase.orig_conv_window)) .ValueOrDie(); - auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), - window, dnums) - .ValueOrDie(), - input, rhs_pad, window, dnums)); // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place // after the transformation. 
- PrecisionConfigProto precision_config; - precision_config.add_operand_precision(PrecisionConfigProto::HIGH); - precision_config.add_operand_precision(PrecisionConfigProto::HIGHEST); - orig_conv->set_precision_config(precision_config); + PrecisionConfig precision_config; + precision_config.add_operand_precision(PrecisionConfig::HIGH); + precision_config.add_operand_precision(PrecisionConfig::HIGHEST); + + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), + /*feature_group_count=*/1, window, + dnums) + .ValueOrDie(), + input, rhs_pad, /*feature_group_count=*/1, window, dnums, + precision_config)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -2396,9 +2401,10 @@ TEST_P(ConvFilterPaddingTest, DoIt) { conv->operand(1)->shape().dimensions(2), conv->operand(1)->shape().dimensions(3), testcase.expected_conv_window)); - EXPECT_THAT( - conv->precision_config().operand_precision(), - ElementsAre(PrecisionConfigProto::HIGH, PrecisionConfigProto::HIGHEST)); + EXPECT_THAT(Cast(conv) + ->precision_config() + .operand_precision(), + ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST)); } } @@ -2522,8 +2528,9 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { HloInstruction* filter = b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter")); - b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, - window, dnums)); + b.AddInstruction(HloInstruction::CreateConvolve( + out_shape, input, filter, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); // TODO(b/80488902): verify this module. 
auto module = HloTestBase::CreateNewModule(); @@ -2901,7 +2908,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums, + DefaultPrecisionConfig(2))); std::unique_ptr dot_computation(builder.Build()); HloComputation::Builder call_builder(TestName() + ".Call"); @@ -2924,9 +2932,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { HloComputation::Builder builder(TestName()); const float constant_scalar = 7.3f; std::initializer_list constant_vector = {1.1f, 2.0f, 3.3f}; - std::unique_ptr value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get()}); + Literal elements[] = {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector)}; + Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto computation = module().AddEntryComputation(builder.Build()); @@ -3253,8 +3261,8 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -3329,8 +3337,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { dot_dnums.add_rhs_contracting_dimensions(0); Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); - 
builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3393,8 +3401,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { dot_dnums.add_rhs_contracting_dimensions(0); Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3511,8 +3519,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int64 dot_row_size = 1; int64 dot_col_size = spec.n; Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3581,8 +3589,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int64 dot_row_size = spec.m; int64 dot_col_size = 1; Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc 
b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index a16b85a0a5e3f72f54e9733bb974b01377e0c358..eda026ac5685dc469a6230094eb28b3618e36400 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -63,8 +63,8 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size()); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, - MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); - new_dot->set_precision_config(batch_dot->precision_config()); + MakeDotHlo(new_lhs, new_rhs, new_dim_numbers, + batch_dot->precision_config())); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, MakeReshapeHlo(batch_dot->shape(), new_dot)); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index ec281ae68fe76bac4029058997c44b1f7e71aeae..30d33e0d3531bb5e931ebfa0b60c91523dd0cb44 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -205,11 +205,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( const Shape feature_shape = scale->shape(); auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = add(HloInstruction::CreateBroadcast( operand_shape, add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); @@ -331,7 +331,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( const Shape feature_shape = scale->shape(); auto 
epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( operand_shape, computation_->AddInstruction( @@ -464,11 +464,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 elements_per_feature_int64 = size_in_elements / feature_count; auto zero_literal = LiteralUtil::CreateR0(0.0f); - TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon()); - TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype)); auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); auto epsilon_activation = add( @@ -560,7 +560,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( auto elements_per_feature_literal = LiteralUtil::CreateR0(elements_per_feature_int64); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); + elements_per_feature_literal.Convert(ptype)); auto elements_per_feature = add( HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index aba0d9bb5b977d89656580df46838eefb8cd6662..f7ac8f5482908af104554a1cf812370b9098cda7 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -29,14 +29,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { -using BatchNormExpanderTest = HloTestBase; +using BatchNormExpanderTest = HloVerifiedTestBase; // Test that we expand BatchNormTraining. TEST_F(BatchNormExpanderTest, BatchNormTraining) { @@ -66,7 +66,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) { BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -108,7 +108,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) { BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(module).ValueOrDie()); root = computation->root_instruction(); // Make sure this operation is expanded. 
EXPECT_EQ(root->opcode(), HloOpcode::kTuple); @@ -126,13 +126,13 @@ ENTRY entry { epsilon=0.001, feature_index=1, sharding={maximal device=1} })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str)); + ParseAndVerifyModule(module_str); BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); - ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie()); - for (auto* instruction : module->entry_computation()->instructions()) { + for (auto* instruction : module().entry_computation()->instructions()) { if (instruction->opcode() == HloOpcode::kParameter) { continue; } diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 6363a21c3bafe8353a6ebfde405bb7a3736c2074..5f93740887aa7e61458990992fe0573883ff056d 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -65,8 +65,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16ConversionFoldingTest : public HloTestBase { +class BFloat16ConversionFoldingTest : public HloVerifiedTestBase { protected: + BFloat16ConversionFoldingTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16ConversionFolding fold(&bfloat16_support_); @@ -102,7 +106,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module.get())); + EXPECT_TRUE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -137,7 +141,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(mul0->shape().element_type(), F32); @@ -172,7 +176,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert2); EXPECT_EQ(sub0->shape().element_type(), F32); @@ -202,7 +206,7 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { auto 
module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConversions(module.get())); + EXPECT_FALSE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), convert1); EXPECT_EQ(gte->shape().element_type(), F32); @@ -248,7 +252,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConversions(module.get())); + EXPECT_TRUE(FoldConversions(module)); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_EQ(tuple->operand(0), gte_a); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index b08705d4c2b644fe1a7ba9994876fd6397f8a5df..cef0eba14e9dd463d6c32b047211bf25a84478f6 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -68,8 +68,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16NormalizationTest : public HloTestBase { +class BFloat16NormalizationTest : public HloVerifiedTestBase { protected: + BFloat16NormalizationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + bool Normalize(HloModule* module) { TestBFloat16Support bfloat16_support_; BFloat16Normalization normalization(&bfloat16_support_); @@ -105,7 +109,7 @@ TEST_F(BFloat16NormalizationTest, NoopIfSupported) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(Normalize(module.get())); + EXPECT_FALSE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), add1); EXPECT_EQ(add0->shape().element_type(), BF16); @@ -133,7 +137,7 @@ TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), mul1); @@ -163,7 +167,7 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(computation->root_instruction()->operand(0), sub1); @@ -201,7 +205,7 @@ TEST_F(BFloat16NormalizationTest, 
ResolveUnsupportedMixedPrecisionReduce) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), reduce); EXPECT_EQ(reduce->called_computations().size(), 1); @@ -259,7 +263,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -286,7 +290,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) { auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction(), gte); EXPECT_EQ(gte->shape().element_type(), BF16); @@ -308,13 +312,16 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(Normalize(module.get())); + EXPECT_TRUE(Normalize(module)); EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); EXPECT_EQ(dot->shape().element_type(), F32); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 545a6ecfb1fca88c2c759e820f9d87a38b1941ca..58f78f8e24d0bc00a63e3583828cf8e01ae4531a 100644 --- 
a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -675,10 +675,8 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) { continue; } if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) { - TF_ASSIGN_OR_RETURN( - auto converted_literal, - hlo->literal().ConvertToShape(hlo->shape(), - /*round_f32_to_bf16=*/true)); + TF_ASSIGN_OR_RETURN(auto converted_literal, + hlo->literal().ConvertToShape(hlo->shape())); auto new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(converted_literal))); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 69b654d30e42b1ed69304206f09120e86831d468..e032b5c624c0151fd63c870e0f21ec97656d625f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -55,8 +55,12 @@ class TestBFloat16Support : public BFloat16Support { } }; -class BFloat16PropagationTest : public HloTestBase { +class BFloat16PropagationTest : public HloVerifiedTestBase { protected: + BFloat16PropagationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true) {} + // Runs the propagation pass on the given module, and returns whether the // module is changed after this pass. 
bool PropagatePrecision(HloModule* module) { @@ -77,6 +81,16 @@ class BFloat16PropagationTest : public HloTestBase { inst->users()[0]->opcode() == HloOpcode::kConvert && inst->users()[0]->shape().element_type() == BF16; } + + std::unique_ptr CreateDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, + DefaultPrecisionConfig(2)); + } }; // Tests that BF16 can propagate through select over non-tuple buffers, but not @@ -95,22 +109,22 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); HloInstruction* add1 = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b)); - HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b)); + HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b)); HloInstruction* sel = builder.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1)); HloInstruction* xpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a)); - HloInstruction* root = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a)); + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - 
EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(xpose)); @@ -136,13 +150,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a))); HloInstruction* b = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b))); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(dot->operand(0))); @@ -150,10 +163,10 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)), + LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)), dot->operand(0)->literal())); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)), + LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)), dot->operand(1)->literal())); } @@ -189,8 +202,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple0->shape(), tuple1, 0)), 0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); HloInstruction* output_tuple = 
builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); @@ -198,7 +211,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), output_tuple); EXPECT_TRUE(OutputsBF16(xpose)); @@ -231,13 +244,13 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1)); // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1. - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(add1)); @@ -249,7 +262,7 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { // Tests that a non-fusion computation's root should not be changed. 
TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* a = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); @@ -258,8 +271,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add)); HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); @@ -267,7 +279,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), tuple); EXPECT_FALSE(OutputsBF16(add)); @@ -277,7 +289,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); @@ -303,15 +315,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { HloInstruction::CreateGetTupleElement(shape, p_f1, 0)); HloInstruction* b_f1 = builder_f1.AddInstruction( HloInstruction::CreateGetTupleElement(shape, p_f1, 1)); - HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1)); + HloInstruction* 
dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1)); auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build()); auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion( dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), fusion1); EXPECT_TRUE(OutputsBF16(add)); @@ -326,7 +337,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); HloInstruction* param = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param")); @@ -340,15 +351,15 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); HloInstruction* add_f = builder_f.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); - HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f)); + HloInstruction* dot_f = + builder_f.AddInstruction(CreateDot(shape, add_f, add_f)); auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), fusion); } @@ -390,12 +401,11 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { 
HloInstruction::CreateGetTupleElement(shape, fusion, 0)); HloInstruction* gte1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, fusion, 1)); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(gte0)); @@ -440,12 +450,12 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { HloInstruction* xpose = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1)); + HloInstruction* dot = builder.AddInstruction( + CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, gte1)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add0)); @@ -472,31 +482,36 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { auto builder_cond = HloComputation::Builder("cond"); auto cond_param = builder_cond.AddInstruction( HloInstruction::CreateParameter(0, shape, "cond_param")); - auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond_param, cond_param)); + auto cond_dot = + builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param)); auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, 
{1, 1})), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); auto body_param = builder_body.AddInstruction( HloInstruction::CreateParameter(0, shape, "body_param")); - auto body_dot = builder_body.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, body_param, body_param)); + auto body_dot = + builder_body.AddInstruction(CreateDot(shape, body_param, body_param)); auto body = module->AddEmbeddedComputation(builder_body.Build()); auto while_hlo = builder.AddInstruction( HloInstruction::CreateWhile(shape, cond, body, add)); - auto dot = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE( @@ -528,10 +543,16 @@ TEST_F(BFloat16PropagationTest, HloInstruction::CreateParameter(0, shape, "cond_param")); builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})), - 
builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1}, + {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -552,11 +573,10 @@ TEST_F(BFloat16PropagationTest, auto while_hlo = builder.AddInstruction( HloInstruction::CreateWhile(shape, cond, body, add)); - auto dot = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_FALSE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_FALSE(OutputsBF16(add)); EXPECT_FALSE(OutputsBF16(body_fusion)); @@ -593,14 +613,20 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { // This add should prevent RHS from using BF16 auto cond_add_rhs = builder_cond.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs)); - auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond_lhs, cond_add_rhs)); + auto cond_dot = + builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs)); builder_cond.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond.AddInstruction(HloInstruction::CreateSlice( - 
ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond = module->AddEmbeddedComputation(builder_cond.Build()); auto builder_body = HloComputation::Builder("body"); @@ -610,10 +636,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot1 = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); - auto body_dot2 = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs)); + auto body_dot1 = + builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs)); + auto body_dot2 = + builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs)); auto body_transpose = builder_body.AddInstruction( HloInstruction::CreateTranspose(shape, body_dot2, {0, 1})); builder_body.AddInstruction( @@ -627,11 +653,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { HloInstruction::CreateGetTupleElement(shape, while_hlo, 0)); auto rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, while_hlo, 1)); - auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto dot = builder.AddInstruction(CreateDot(shape, lhs, 
rhs)); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), dot); EXPECT_TRUE(OutputsBF16(lhs)); @@ -683,14 +708,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto cond0_add_rhs = builder_cond0.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs)); - auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs)); + auto cond0_dot = + builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs)); builder_cond0.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond0.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond0.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond0.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond0_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond0.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond0.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build()); // Condition computation for the second while. 
@@ -705,14 +736,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto cond1_add_lhs = builder_cond1.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs)); - auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs)); + auto cond1_dot = + builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs)); builder_cond1.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})), - builder_cond1.AddInstruction(HloInstruction::CreateSlice( - ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1})))); + builder_cond1.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond1.AddInstruction( + HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}), + cond1_dot, {0, 0}, {1, 1}, {1, 1})))), + builder_cond1.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {}), + builder_cond1.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2}, + {1, 1})))))); auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build()); // Body computation shared by both whiles. 
@@ -723,8 +760,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { HloInstruction::CreateGetTupleElement(shape, body_param, 0)); auto body_rhs = builder_body.AddInstruction( HloInstruction::CreateGetTupleElement(shape, body_param, 1)); - auto body_dot = builder_body.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs)); + auto body_dot = + builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs)); builder_body.AddInstruction( HloInstruction::CreateTuple({body_dot, body_rhs})); auto body = module->AddEmbeddedComputation(builder_body.Build()); @@ -734,23 +771,22 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) { auto while1 = builder.AddInstruction( HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1)); - auto lhs = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while0, 0)), - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while0, 1)))); - auto rhs = builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while1, 0)), - builder.AddInstruction( - HloInstruction::CreateGetTupleElement(shape, while1, 1)))); - auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs)); + auto lhs = builder.AddInstruction( + CreateDot(shape, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while0, 1)))); + auto rhs = builder.AddInstruction( + CreateDot(shape, + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 0)), + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, while1, 1)))); + auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs)); auto 
computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_FALSE(OutputsBF16(body_dot)); EXPECT_FALSE(OutputsBF16(body_rhs)); EXPECT_FALSE(OutputsBF16(body_lhs)); @@ -792,7 +828,7 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), add2); EXPECT_EQ(add2->operand(0), add0); @@ -821,15 +857,14 @@ TEST_F(BFloat16PropagationTest, TupleDomain) { HloInstruction::CreateGetTupleElement(shape, domain, 0)); HloInstruction* b_gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(shape, domain, 1)); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte)); + HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a_gte, b_gte)); HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); // test BF16 propagated through domain @@ -867,15 +902,15 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { HloInstruction::CreateTranspose(shape, a_gte, {0, 1})); HloInstruction* b_trans = builder.AddInstruction( HloInstruction::CreateTranspose(shape, b_gte, {0, 1})); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans)); + HloInstruction* dot = + builder.AddInstruction(CreateDot(shape, a_trans, b_trans)); HloInstruction* root = builder.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, 
dot)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(PropagatePrecision(module.get())); + EXPECT_TRUE(PropagatePrecision(module)); EXPECT_EQ(computation->root_instruction(), root); EXPECT_TRUE(OutputsBF16(a_trans)); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 8b8c6bfd269971efa6fcd186e4825e6f13bb4094..65fa951afe3e60652413206913640af38f5bb824 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -617,18 +616,24 @@ Status BufferAssignment::ComputeSummaryStats() { } // Only compute total fragmentation if all computations have schedules. 
- SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(module_); + bool schedule_complete = true; for (const auto& computation : module_->computations()) { - const std::vector* sequence = - liveness_->hlo_ordering().SequentialOrder(*computation); - if (sequence != nullptr) { - module_sequence.emplace(computation, *sequence); + if (!computation->IsFusionComputation()) { + const std::vector* sequence = + liveness_->hlo_ordering().SequentialOrder(*computation); + if (sequence == nullptr) { + schedule_complete = false; + } else { + schedule.set_sequence(computation, *sequence); + } } } - if (module_sequence.size() == module_->computation_count()) { + if (schedule_complete) { + TF_RETURN_IF_ERROR(schedule.Verify()); TF_ASSIGN_OR_RETURN( const int64 min_size, - HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_)); + HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_)); stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } @@ -1064,7 +1069,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( // since buffers for kCall, kWhile, and kConditional sub-computations are // only live for the duration of their calling instructions. 
VLOG(1) << "Running whole-module heap simulation"; - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(&assignment->module()); FlatSet all_buffers_to_assign; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; @@ -1072,7 +1077,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const std::vector* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); - module_sequence[computation] = *instruction_sequence; + schedule.set_sequence(computation, *instruction_sequence); all_buffers_to_assign.insert(buffers_to_assign.begin(), buffers_to_assign.end()); } @@ -1090,7 +1095,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HeapSimulator::Result result, HeapSimulator::Run(absl::make_unique( absl::make_unique(alignment)), - assignment->module(), module_sequence, + assignment->module(), schedule, assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, @@ -1121,7 +1126,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( HeapSimulator::Run( absl::make_unique( absl::make_unique(alignment)), - *computation, *instruction_sequence, + *computation, HloInstructionSequence(*instruction_sequence), assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8bd1533972413194dec3609829c8cf8df570cc2a..795beb9ff5ceb2998a85fbd03d8bb1d3b2febc12 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -30,16 +30,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -120,14 +122,10 @@ class BufferAssignmentTest : public HloVerifiedTestBase { HloModule* module, absl::Span instruction_sequence, int64 alignment = 1) { - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[module->entry_computation()] = - std::vector(instruction_sequence.begin(), - instruction_sequence.end()); + HloSchedule schedule(module); + schedule.set_sequence(module->entry_computation(), instruction_sequence); return BufferAssigner::Run( - module, - absl::make_unique(module, - module_sequence), + module, absl::make_unique(schedule), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -1247,9 +1245,10 @@ TEST_F(BufferAssignmentTest, TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output // is properly handled. 
auto builder = HloComputation::Builder(TestName()); + Literal elements[] = {LiteralUtil::CreateR0(0), + LiteralUtil::CreateR0(1)}; builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}))); + LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1490,10 +1489,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot_ab = builder.AddInstruction( - HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums)); - auto dot_bc = builder.AddInstruction( - HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot( + shape_2x4, param_a, param_b, dot_dnums, precision_config)); + auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot( + shape_3x4, param_b, param_c, dot_dnums, precision_config)); builder.AddInstruction( HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0)); @@ -1782,11 +1784,10 @@ class WhileBufferAssignmentTest : public HloVerifiedTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { - auto sequence = - ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); + HloSchedule schedule = + ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, - absl::make_unique(module, sequence), + module, absl::make_unique(schedule), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -2093,17 +2094,25 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // Create a sequential order among all the 
instructions in the entry // computation, since the issue this test stresses depends on the order the // nodes are traversed during BufferAssignment. - SequentialHloOrdering::HloModuleSequence sequence; - sequence[module->entry_computation()] = { - token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}; + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + schedule.set_sequence( + module->entry_computation(), + {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}); + TF_ASSERT_OK(schedule.Verify()); + TF_ASSERT_OK_AND_ASSIGN( auto assignment, - BufferAssigner::Run( - module, absl::make_unique(module, sequence), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module, + absl::make_unique(schedule), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -2260,29 +2269,6 @@ ENTRY Main { GetAllocation(*buffers, param0, {1, 1})); } -static bool IsPostOrderTraversal( - const std::vector& sequence) { - tensorflow::gtl::FlatSet seen_so_far; - auto has_not_been_seen_yet = [&](const HloInstruction* instruction) { - return seen_so_far.count(instruction) == 0; - }; - - for (auto instruction : sequence) { - if (std::any_of(instruction->operands().begin(), - instruction->operands().end(), has_not_been_seen_yet) || - std::any_of(instruction->control_predecessors().begin(), - instruction->control_predecessors().end(), - has_not_been_seen_yet)) { - return false; // Not a post order. 
- } - if (!seen_so_far.insert(instruction).second) { - return false; // Not a "traversal". - } - } - - return true; -} - TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); @@ -2337,27 +2323,27 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { RunCopyInsertion(module); - auto sequence = - ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); + HloSchedule schedule = + ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); - // To trigger b/38494731, we want a specific Hlo sequence for the + // To trigger b/38494731, we want a specific Hlo schedule for the // root computation, so we overwrite that entry with a manually // crafted sequence. - sequence[module->entry_computation()] = { - input1, weights1, one, output1, while1->operand(0), while1, - input0, weights0, zero, output0, while0->operand(0), while0, - gte0, gte1, root_add}; + schedule.set_sequence(module->entry_computation(), + {input1, weights1, one, output1, while1->operand(0), + while1, input0, weights0, zero, output0, + while0->operand(0), while0, gte0, gte1, root_add}); - // If this ASSERT_TRUE fails, we constructed a bogus sequence above - // and this test itself is buggy. - ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); + // If this ASSERT fails, we constructed a bogus sequence above and this test + // itself is buggy. 
+ TF_ASSERT_OK(schedule.Verify()); auto assignment = - BufferAssigner::Run( - module, absl::make_unique(module, sequence), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true) + BufferAssigner::Run(module, + absl::make_unique(schedule), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 26e26e316d6281a97f8317f8ed1d7a6f21b0d374..17e50905059ad2c92784d14132c1cb1f46f35ade 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -166,12 +167,12 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto module = CreateNewModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique( - module.get(), sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param0, negate, param1, exp, add}); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. 
@@ -291,13 +292,12 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - SequentialHloOrdering::HloModuleSequence module_sequence; - std::vector order = {param, negate, exp, add}; - module_sequence.emplace(computation, order); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param, negate, exp, add}); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) + .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -339,14 +339,14 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build(add)); - SequentialHloOrdering::HloModuleSequence module_sequence; - std::vector order = {param, add, recv, - recv_done, send, send_done}; - module_sequence.emplace(computation, order); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, + {param, add, token, recv, recv_done, send, send_done}); + TF_ASSERT_OK(schedule.Verify()); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) + .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); // Check the root instruction (add) buffer interferes with the recv buffer. @@ -440,15 +440,15 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { // computation. The buffer containing {0, 1} is copied by GetTupleElement, and // the buffers containing {3} and 3 are dead. 
auto builder = HloComputation::Builder(TestName()); - auto inner_tuple0 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}); - auto inner_tuple1 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(3).get()}); + Literal elements0[] = {LiteralUtil::CreateR0(0), + LiteralUtil::CreateR0(1)}; + auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]}); + Literal element1 = LiteralUtil::CreateR0(3); + auto inner_tuple1 = LiteralUtil::MakeTuple({&element1}); auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1}))); builder.AddInstruction(HloInstruction::CreateGetTupleElement( - inner_tuple0->shape(), tuple_constant, 0)); + inner_tuple0.shape(), tuple_constant, 0)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index cc80b7484313329104eec1ce71a150b47d8330c9..34f3f914d593bc603c4964663f9cafb70a136fd3 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,7 +31,7 @@ namespace { using ::testing::UnorderedElementsAre; -class CallGraphTest : public HloTestBase { +class CallGraphTest : public HloVerifiedTestBase { protected: // Build and return a trivial computation taking and returning a scalar. 
std::unique_ptr MakeScalarComputation( @@ -96,7 +96,7 @@ TEST_F(CallGraphTest, SingletonComputation) { auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(1, call_graph->nodes().size()); EXPECT_TRUE(call_graph->IsFlattened()); @@ -118,7 +118,7 @@ TEST_F(CallGraphTest, UnreachableComputation) { HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -140,7 +140,7 @@ TEST_F(CallGraphTest, ParallelComputation) { HloComputation* entry_computation = module->AddEntryComputation( MakeMappingComputation(map_computation, /*callsites=*/5)); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); @@ -169,7 +169,7 @@ TEST_F(CallGraphTest, SequentialComputations) { HloComputation* entry_computation = module->AddEntryComputation( MakeCallingComputation(called_computation, /*callsites=*/3)); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); // The called computation is only called from one other computation, but there @@ -210,7 +210,7 @@ TEST_F(CallGraphTest, ContextBothComputations) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(2, call_graph->nodes().size()); 
EXPECT_FALSE(call_graph->IsFlattened()); @@ -259,7 +259,7 @@ TEST_F(CallGraphTest, ComputationWithConditional) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(3, call_graph->nodes().size()); @@ -328,7 +328,7 @@ TEST_F(CallGraphTest, ComplexGraph) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(5, call_graph->nodes().size()); EXPECT_FALSE(call_graph->IsFlattened()); @@ -452,7 +452,7 @@ TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { entry_computation = module->AddEntryComputation(builder.Build()); } - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(5, call_graph->nodes().size()); // Verify NearestAncestorsInSameComputation for various instructions in the @@ -482,7 +482,7 @@ TEST_F(CallGraphTest, VisitSingletonComputation) { auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); std::vector visited; TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) { @@ -499,7 +499,7 @@ TEST_F(CallGraphTest, VisitUnreachableComputation) { module->AddEntryComputation(MakeScalarComputation()); HloComputation* unreachable_computation = module->AddEmbeddedComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); // Test visitation of only reachable nodes. { @@ -533,7 +533,7 @@ TEST_F(CallGraphTest, VisitWithError) { // Test that the call graph visitor properly propagates errors. 
auto module = CreateNewModule(); module->AddEntryComputation(MakeScalarComputation()); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); Status status = call_graph->VisitNodes( [](const CallGraphNode&) { return InternalError("Visitation failed"); }); diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 5d85a3f173d50a964420e720f5c9b416731d948c..e6b566543594a86eb5369ee9b7440f62618f6c5a 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -40,7 +40,7 @@ namespace { // Tests for call inlining that are most tractable at the HLO level (vs // ComputationBuilder API in call_test.cc). 
-using CallInlinerTest = HloTestBase; +using CallInlinerTest = HloVerifiedTestBase; TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // "inner" computation just has a control dependency from the "zero" value to @@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), @@ -91,6 +91,8 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { module->AddEmbeddedComputation(just_false.Build()); HloComputation::Builder call_false_builder(TestName() + ".call_false"); + call_false_builder.AddInstruction( + HloInstruction::CreateParameter(0, pred, "param")); call_false_builder.AddInstruction( HloInstruction::CreateCall(pred, {}, false_computation)); HloComputation* call_false = @@ -105,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); EXPECT_THAT( computation->root_instruction()->while_condition()->root_instruction(), @@ -161,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); } diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc 
index 9c81a86bbb9dc7078237fe200f510a4905cb4d8d..0ac4a65ec6ae55fabd2b48ea2982b94f9551c8d2 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -214,8 +214,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { expanded_filter = add(HloInstruction::CreateConcatenate( expanded_filter_shape, concat_operands, input_feature_dim)); } - auto zero = add(HloInstruction::CreateConstant(absl::make_unique( - LiteralUtil::Zero(expanded_filter_shape.element_type())))); + auto zero = add(HloInstruction::CreateConstant( + LiteralUtil::Zero(expanded_filter_shape.element_type()))); auto zero_filter = add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); auto new_filter = add( @@ -223,8 +223,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { filter_mask, expanded_filter, zero_filter)); auto new_convolution = HloInstruction::CreateConvolve( convolution->shape(), convolution->mutable_operand(0), new_filter, - convolution->window(), dim_numbers, /*feature_group_count=*/1); - new_convolution->set_precision_config(convolution->precision_config()); + /*feature_group_count=*/1, convolution->window(), dim_numbers, + convolution->precision_config()); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(new_convolution))); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index d412578619e5d23db3933af19d665cf8beb4d622..8cc522a59e9805ec86e9e69c8d6e5fa1a3ab682d 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -122,7 +122,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:hlo_proto_util", - "//tensorflow/compiler/xla/service:hlo_scheduling", + 
"//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", @@ -670,6 +670,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/strings", @@ -800,6 +801,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -821,6 +823,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -945,6 +948,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -970,6 +974,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 
098ce17a568fd3fb531020e7731100fabda43721..2d9978404cc9ec1e40fc61aaf794a8f1f06050bb 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -130,9 +130,9 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { // change the dimension mapping but not the dimension sizes. For // example, input height and width are the same as before the reshapes. HloInstruction* new_conv = module->entry_computation()->AddInstruction( - HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, - hlo->window(), new_dnums)); - new_conv->set_precision_config(hlo->precision_config()); + HloInstruction::CreateConvolve( + new_conv_shape, new_input, new_kernel, hlo->feature_group_count(), + hlo->window(), new_dnums, hlo->precision_config())); // Reshape the output back to the shape of the original convolution. TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 547d4c696da5cfdde3dece03250ae5fa51c92f25..2083f440fdd971db1b675d005664d25e6de53dbe 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,7 +32,7 @@ namespace cpu { using ::testing::ElementsAre; -class ConvCanonicalizationTest : public HloTestBase { +class ConvCanonicalizationTest : public HloVerifiedTestBase { public: ConvCanonicalizationTest() { for (int i = 0; i < 2; ++i) { @@ -84,7 +84,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}), - input, kernel, conv_window_, dnums)); + input, kernel, /*feature_group_count=*/1, conv_window_, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -95,7 +96,7 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode()); @@ -146,7 +147,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}), - input, kernel, conv_window_, dnums)); + input, kernel, /*feature_group_count=*/1, conv_window_, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -156,7 +158,7 @@ 
TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); ConvCanonicalization conv_canonicalization(&target_machine_features); - EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie()); } } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 796f36510e414cde692208cfe0cf9626acae63d3..18fc144efe0023c0893adfcb16eda3341c0938d3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -77,12 +77,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" @@ -584,16 +584,14 @@ StatusOr> CpuCompiler::RunBackend( // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). 
TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), - DFSMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run(module.get(), - absl::make_unique( - module.get(), module_sequence), + absl::make_unique(schedule), BufferSizeBytesFunction(), memory_alignment, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); @@ -627,9 +625,10 @@ StatusOr> CpuCompiler::RunBackend( } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) + .EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &schedule.sequence(embedded_computation).instructions()) .status()); } string function_name_prefix = entry_computation->name().empty() @@ -637,9 +636,10 @@ StatusOr> CpuCompiler::RunBackend( : entry_computation->name(); TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, - ir_emitter.EmitComputation(entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &module_sequence.at(entry_computation))); + ir_emitter.EmitComputation( + entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + &schedule.sequence(entry_computation).instructions())); string function_name = [&]() { llvm::SmallVector function_name_vector; @@ -771,20 +771,18 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - ScheduleComputationsInModule(*module, BufferSizeBytesFunction())); + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + 
ScheduleModule(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module, - absl::make_unique(module, module_sequence), - BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module, + absl::make_unique(schedule), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -824,18 +822,18 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation(embedded_computation, - embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) + .EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &schedule.sequence(embedded_computation).instructions()) .status()); } const string& entry_point_name = options.entry_point_name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation(computation, entry_point_name, - /*is_top_level_computation=*/true, - &module_sequence.at(computation))); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + computation, entry_point_name, + /*is_top_level_computation=*/true, + &schedule.sequence(computation).instructions())); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc index 4db7fa446ea9188940f930bcadf753bd3e6b79e3..c9fb34be1cd582c71618c770c892058c233c571a 
100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -52,7 +52,7 @@ int64 CountCopies(const HloModule& module) { return count; } -class CpuCopyInsertionTest : public HloTestBase { +class CpuCopyInsertionTest : public HloVerifiedTestBase { protected: void InsertCopies(HloModule* module) { CpuCopyInsertion copy_insertion; @@ -90,7 +90,7 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { module->AddEntryComputation(builder.Build()); - InsertCopies(module.get()); + InsertCopies(module); EXPECT_EQ(CountCopies(*module), 3); @@ -127,7 +127,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) { module->AddEntryComputation(builder.Build()); - InsertCopies(module.get()); + InsertCopies(module); EXPECT_EQ(CountCopies(*subcomputation), 2); EXPECT_THAT(subcomputation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc index 0f463e6de623fc6ab43d685ff2a5d6882ba7b8a2..be1208fb2df2a1a11a093810b5f6c2a83f468062 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class CpuHloSupportCheckerTest : public HloTestBase { +class CpuHloSupportCheckerTest : public HloVerifiedTestBase { protected: CpuHloSupportChecker& checker() { return checker_; } @@ -45,7 +45,7 @@ TEST_F(CpuHloSupportCheckerTest, Add) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module.get()).status()); + TF_ASSERT_OK(checker().Run(module).status()); } TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module.get()).status(); + Status status = checker().Run(module).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("CPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 284929ca073ca0d8c5c7cc383f8341a53d0f9e88..7d99b914d4f5e5d27722bcd098d2ae0c54a36a23 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" namespace op = xla::testing::opcode_matchers; @@ -38,7 +39,11 @@ std::unique_ptr MakeDot(const Shape& shape, HloInstruction* lhs, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, + precision_config); } TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { @@ -692,8 +697,8 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name, auto* addend = builder.AddInstruction( HloInstruction::CreateParameter(2, dot_shape, "param2")); - auto* dot = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + auto* dot = + builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); builder.AddInstruction( HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 9363af3b8941c68284915d6770188bde4c87f78e..4668f3872dad598edf4c7680e1b601622104ab3e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -70,7 +70,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + 
CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -107,9 +107,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto dot_a_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); auto dot_b_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); @@ -151,9 +151,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); auto dot_a_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); + CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); auto dot_b_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); + CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); @@ -189,7 +189,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateParameter(0, rhs_shape, "param0")); auto dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -229,7 +229,7 @@ TEST_F(CpuLayoutAssignmentTest, 
DotWithConstantRhsTensorThroughGTE) { auto dot_rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1)); auto dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); + CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -276,8 +276,8 @@ static StatusOr RunDotOutputFusion( HloInstruction::CreateParameter(1, dot_shape, "param1")); HloInstruction* dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape))); - HloInstruction* dot_result = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + HloInstruction* dot_result = + builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); HloInstruction* add_result; if (dot_operand_idx_in_add == 0) { add_result = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index e5cf15c686157d837901fa912bdde2a7a5d501d9..df8c2a636bbda52e3a8df00015ce3f27e6ba1aea 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -110,7 +110,7 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - std::vector* instruction_order) { + const std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]; ordered? 
" << (instruction_order != nullptr); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 58a333b8fb2dc46868b04fec0d7d87788a809d06..3df99464ba1103488b9fe054593740ada108d3da 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -98,7 +98,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - std::vector* instruction_order); + const std::vector* instruction_order); llvm::IRBuilder<>* b() { return &b_; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index a84ee78b19981e480858320e445de7f5dae27d61..fad76338a57cd9eb21d9469ca8552efa8ea0129b 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -35,9 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false), - target_machine_features_([](int64 shape_size) { + : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 942e2ddd3940fffd5d87518f059beaced3cdc925..55d5925642a97b1a0425c092c82070d4b8e59df3 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -37,21 +37,20 @@ int main(int argc, char** argv) { xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie()); // 
Transfer parameters. - std::unique_ptr param0_literal = + xla::Literal param0_literal = xla::LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR2( - {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + xla::Literal param1_literal = xla::LiteralUtil::CreateR2( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); std::unique_ptr param1_data = - client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client->TransferToServer(param1_literal).ConsumeValueOrDie(); // Build computation. xla::XlaBuilder builder(""); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p1, p0, {0}); xla::StatusOr computation_status = builder.Build(); @@ -59,17 +58,16 @@ int main(int argc, char** argv) { // Execute and transfer result of computation. 
xla::ExecutionProfile profile; - xla::StatusOr> result = - client->ExecuteAndTransfer( - computation, - /*arguments=*/{param0_data.get(), param1_data.get()}, - /*execution_options=*/nullptr, - /*execution_profile=*/&profile); - std::unique_ptr actual = result.ConsumeValueOrDie(); + xla::StatusOr result = client->ExecuteAndTransfer( + computation, + /*arguments=*/{param0_data.get(), param1_data.get()}, + /*execution_options=*/nullptr, + /*execution_profile=*/&profile); + xla::Literal actual = result.ConsumeValueOrDie(); LOG(INFO) << absl::StrFormat("computation took %dns", profile.compute_time_ns()); - LOG(INFO) << actual->ToString(); + LOG(INFO) << actual.ToString(); return 0; } diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc index 7d8e51f909e3db699b745f94a6c625407bc4a6e3..1a3d82de954318368d61e3feeb0345dc592dcd8b 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace cpu { namespace { -class ShapePartitionAssignerTest : public HloTestBase { +class ShapePartitionAssignerTest : public HloVerifiedTestBase { protected: typedef std::vector Vec; @@ -91,7 +91,7 @@ TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { expected_partitions); } -class ShapePartitionIteratorTest : public HloTestBase { +class ShapePartitionIteratorTest : public HloVerifiedTestBase { protected: typedef std::vector> Partition; }; @@ -145,7 +145,7 @@ TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { } } -class RandomShapePartitionIteratorTest : public HloTestBase { +class RandomShapePartitionIteratorTest : public HloVerifiedTestBase { protected: typedef std::vector> Partition; RandomShapePartitionIteratorTest() diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 2384166fd2002a67a8aa785ad5fb341d037ee01f..c55206eee7ae3c6e4410c59aebf529de98fd2de8 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -48,6 +48,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -121,6 +122,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git 
a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index fcd87b36b32915773546c211d7d2c447a69bef49..18ee25ba9158c28baaf01492c290638b9673f1ec 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -69,8 +70,7 @@ TEST_P(CpuEigenDotOperationTest, SimpleDotOp) { HloInstruction* rhs = builder.AddInstruction( HloInstruction::CreateParameter(1, param_shape, "input")); - builder.AddInstruction( - HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs)); + builder.AddInstruction(CreateCanonicalDot(param_shape, lhs, rhs)); CompileAndCheck(builder.Build(), spec.filecheck_lines); } @@ -87,8 +87,7 @@ TEST_P(CpuEigenDotOperationTest, DotTransposeOp) { HloInstruction* lhs_transposed = builder.AddInstruction( HloInstruction::CreateTranspose(param_shape, lhs, {1, 0})); - builder.AddInstruction( - HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs)); + builder.AddInstruction(CreateCanonicalDot(param_shape, lhs_transposed, rhs)); CompileAndCheck(builder.Build(), spec.filecheck_lines); } diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index 22721051e54e2cf9590b60333c51d1d028bb28e9..1deb412064b02988a8d4a6d726969c948d354d47 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -34,7 +34,7 @@ namespace xla { namespace cpu { namespace { -class CpuFusionTest : public HloTestBase { +class CpuFusionTest : public HloVerifiedTestBase { protected: CpuFusionTest() {} @@ -45,7 +45,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { auto builder = HloComputation::Builder(TestName()); auto input_literal1 = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); auto input_literal2 = LiteralUtil::CreateR1({-2.0, -42.0, 2.0}); - Shape vshape = input_literal1->shape(); + Shape vshape = input_literal1.shape(); auto input1 = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal1))); @@ -61,7 +61,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -75,16 +75,16 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { EXPECT_EQ(4, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. 
- LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, *result, error_spec_); + LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, result, error_spec_); } TEST_F(CpuFusionTest, FuseElementwiseOpChain) { auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto input = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -108,7 +108,7 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -122,11 +122,10 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { EXPECT_EQ(8, fusion_instruction->fused_instruction_count()); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. 
- LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, *result, - error_spec_); + LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, result, error_spec_); } TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { @@ -135,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({-1.5, -2.5, -3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto input = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -184,7 +183,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The computation root instruction was fused. Verify the fusion instruction // is now the root. @@ -209,11 +208,11 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { << fusion_instruction2->fused_instructions_computation()->ToString(); // Compile and execute the computation. - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); // Check the output correctness. LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0, 14.0, 40.0, 40.0}, - *result, error_spec_); + result, error_spec_); } TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { @@ -232,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // each fusion instruction to ensure that negate is not duplicated. 
auto builder = HloComputation::Builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); - Shape vshape = input_literal->shape(); + Shape vshape = input_literal.shape(); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); @@ -256,7 +255,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { // Run fusion. CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); auto fusion1 = result->operand(0); auto fusion2 = result->operand(1); @@ -315,7 +314,7 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { module->AddEntryComputation(builder.Build()); CpuInstructionFusion fusion; - EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(fusion.Run(module).ValueOrDie()); // The only fusion instruction should be operand 0 of the tuple (formerly // negate1). diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc index c35569c6619ba5b534c5d8bb7ad683d84b6ecf4b..5cc6d01c0f15d4209cbc1fb259a0078fb9957f6e 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); + TestInfeedRoundTrip(LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - 
{{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0minor)); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR4( + TestInfeedRoundTrip(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(InfeedTest, SingleInfeedTuple) { - TestInfeedRoundTrip( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(false).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(false)})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); + TestInfeedRoundTrip(LiteralUtil::MakeTuple({})); } // Tests Infeed operation used in a while loop, as in the code below. The @@ -157,21 +157,21 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { // Send 5 Infeed data of shape F32[3]. 
ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({1, 2, 3}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({1, 2, 3}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({4, 5, 6}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({4, 5, 6}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({7, 8, 9}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({7, 8, 9}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({10, 11, 12}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({10, 11, 12}))); ASSERT_IS_OK( - client_->TransferToInfeed(*LiteralUtil::CreateR1({13, 14, 15}))); + client_->TransferToInfeed(LiteralUtil::CreateR1({13, 14, 15}))); delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); // Only the first 3 infeed data should be added. - LiteralTestUtil::ExpectR0Near(45.0f, *result_literal, ErrorSpec{1e-7}); + LiteralTestUtil::ExpectR0Near(45.0f, result_literal, ErrorSpec{1e-7}); } // Tests two Infeed operations with a total order. The order is enforced by @@ -250,17 +250,17 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). 
ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({3, 4}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({3, 4}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({5, 6}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({5, 6}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8}).get(), - LiteralUtil::CreateR0(false).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({7, 8}), + LiteralUtil::CreateR0(false)}))); // Asynchronously launch the execution on the device. std::unique_ptr result; @@ -275,21 +275,21 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). 
sleep(1); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(true)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({7, 8, 9}).get(), - LiteralUtil::CreateR0(false).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({7, 8, 9}), + LiteralUtil::CreateR0(false)}))); ASSERT_IS_OK(client_->TransferToInfeed( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({4, 5, 6}).get(), - LiteralUtil::CreateR0(true).get()}))); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({4, 5, 6}), + LiteralUtil::CreateR0(true)}))); // Wait for the execution to be done, and transfer the result. delete computation_thread; // Joins the thread. auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); // Only the first 6 infeed data should be added. - LiteralTestUtil::ExpectR0Near(66.0f, *result_literal, ErrorSpec{1e-7}); + LiteralTestUtil::ExpectR0Near(66.0f, result_literal, ErrorSpec{1e-7}); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index bb105194f1c9001ca4d9fff9174e1ea7e5d8b72a..7af51db55af44ae1e437ea8e4de7427012cad82f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -41,8 +41,7 @@ class CpuNoAliasTest : public CpuCodegenTest {}; TEST_F(CpuNoAliasTest, Concat) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); HloInstruction* param_x = builder.AddInstruction( HloInstruction::CreateParameter(0, param_shape, "x")); diff --git 
a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 09cb10d6ee579111b6e0cdb460b9af2b95d090db..b2ba2617902104bfea06713332fa1c2aedea536d 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -134,9 +134,9 @@ Status DecomposeBatchDot(HloInstruction* dot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( - dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); - dot_r2->set_precision_config(dot->precision_config()); + auto dot_r2 = computation->AddInstruction( + HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2, + dot_dnums, dot->precision_config())); // Reshape Dot to R3 so we can concat along batch dimension. auto dot_r3 = computation->AddInstruction( diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index 1b3be199f632a2aa6bd2c5a3820c7c5ce9b1382e..852f34e06df35242b13110ae4411b8c969c26019 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -56,9 +56,9 @@ ENTRY main { } )"; - std::unique_ptr lhs = LiteralUtil::CreateR3({{{1}, {2}}}); - std::unique_ptr rhs = LiteralUtil::CreateR3({{{3}, {4}}}); - RunTest(hlo_text, {lhs.get(), rhs.get()}); + Literal lhs = LiteralUtil::CreateR3({{{1}, {2}}}); + Literal rhs = LiteralUtil::CreateR3({{{3}, {4}}}); + RunTest(hlo_text, {&lhs, &rhs}); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index 8f6608241ed02bbb7e9fde9b6d767c002435e777..5fbd73a5363b4cdbcaafedbe6f4e7bd6bb2a92d8 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ 
b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -30,7 +30,7 @@ limitations under the License. namespace xla { namespace { -class FlattenCallGraphTest : public HloTestBase { +class FlattenCallGraphTest : public HloVerifiedTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation() { @@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) { } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr flat_call_graph = CallGraph::Build(module.get()); + std::unique_ptr flat_call_graph = CallGraph::Build(module); const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation); EXPECT_EQ(1, c_node.caller_callsites().size()); } @@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { } { - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); const CallGraphNode& cond_node = call_graph->GetNode(cond_computation); EXPECT_EQ(2, cond_node.caller_callsites().size()); } { - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); const CallGraphNode& cond_node = 
call_graph->GetNode(cond_computation); EXPECT_EQ(1, cond_node.caller_callsites().size()); } @@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { module->AddEntryComputation( MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry")); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); EXPECT_EQ(7, module->computation_count()); const CallGraphNode& c_node = call_graph->GetNode(c_computation); @@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, module->computation_count()); - TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module)); EXPECT_TRUE(result); - std::unique_ptr call_graph = CallGraph::Build(module.get()); + std::unique_ptr call_graph = CallGraph::Build(module); // The true and false computations must now be different. EXPECT_EQ(3, module->computation_count()); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 4ed91ef18768d09c252d1b73890637227f0ce717..bec02e14f951c6d905b7329be5c02896984279d0 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -125,7 +125,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( device_memory.size()); // Element is array-shaped: transfer array data to device buffer. 
const auto subliteral = LiteralSlice(literal, index); - std::unique_ptr relayed_out_literal; + Literal relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) { @@ -138,7 +138,7 @@ Status GenericTransferManager::TransferLiteralToDeviceAsync( // Relayout data before transferring. relayed_out_literal = subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{}); - source = relayed_out_literal->untyped_data(); + source = relayed_out_literal.untyped_data(); TF_RETURN_IF_ERROR(TransferBufferToDevice( stream, /*size=*/GetByteSizeRequirement(device_subshape), source, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a68b7a1bef81e369dc1bbcd249642e5b80401c64..64b96836280718f13ac5ee9f4a497ed54a273b19 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -108,6 +108,8 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", @@ -172,6 +174,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service:while_loop_analysis", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", @@ -369,6 +372,8 @@ cc_library( srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], deps = [ + ":backend_configs", + ":cudnn_convolution_runner", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", 
@@ -394,6 +399,7 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -480,6 +486,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) @@ -811,9 +818,9 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_memory_scheduler", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], ) @@ -830,6 +837,8 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", @@ -898,6 +907,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 
05448d863dd2cfe69ad70168be40cdea5bc7017f..3a23ac1d634161628b2bd2589d0260022868ba36 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" @@ -30,62 +31,32 @@ namespace gpu { using se::dnn::AlgorithmDesc; -ConvolutionThunk::ConvolutionThunk( - CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, - const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count, - int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo) - : Thunk(Kind::kConvolution, hlo), - convolution_kind_(convolution_kind), - input_buffer_(input_buffer), - filter_buffer_(filter_buffer), - output_buffer_(output_buffer), - tuple_result_buffer_(tuple_result_buffer), - scratch_buffer_(scratch_buffer), - input_shape_(input_shape), - filter_shape_(filter_shape), - output_shape_(output_shape), - window_(window), - dim_nums_(dim_nums), - feature_group_count_(feature_group_count), - algorithm_(algorithm), - tensor_ops_enabled_(tensor_ops_enabled) {} - Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - se::DeviceMemoryBase input_data = - 
buffer_allocations.GetDeviceAddress(input_buffer_); - se::DeviceMemoryBase filter_data = - buffer_allocations.GetDeviceAddress(filter_buffer_); - se::DeviceMemoryBase output_data = - buffer_allocations.GetDeviceAddress(output_buffer_); + CudnnConvParams params; + + params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_); + params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_); + params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_); se::DeviceMemoryBase scratch = buffer_allocations.GetDeviceAddress(scratch_buffer_); - se::dnn::AlgorithmConfig algorithm_config( - se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); + TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, ¶ms)); auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); - TF_RETURN_IF_ERROR(RunCudnnConvolution( - convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, - filter_data, output_data, scratch, window_, dim_nums_, - feature_group_count_, algorithm_config, stream)); + TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream)); // Figure out which of output/input/filter is the result produced by // this op, and write the result tuple. 
void* result_ptr = [&] { - switch (convolution_kind_) { + switch (params.kind) { case CudnnConvKind::kForward: - return output_data.opaque(); + return params.output_buf.opaque(); case CudnnConvKind::kBackwardInput: - return input_data.opaque(); + return params.input_buf.opaque(); case CudnnConvKind::kBackwardFilter: - return filter_data.opaque(); + return params.filter_buf.opaque(); } }(); void* ptrs[] = {result_ptr, scratch.opaque()}; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 68d67c40c56145a137398540e90b75b33642589f..d7d1f91fba7239ed1670119f5df623d025c1d368 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -32,7 +33,7 @@ limitations under the License. namespace xla { namespace gpu { -// This class stores everything that StreamExecutor needs to launch a BNN +// This class stores everything that StreamExecutor needs to launch a DNN // convolution. It is generated by IrEmitter. // // This is thread-compatible. @@ -41,27 +42,24 @@ class ConvolutionThunk : public Thunk { // Constructs a thunk for launching a DNN convolution. When run, it will // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // - // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that - // we should use the default (i.e. baseline) cudnn algorithm. 
- // // Note that "output" here doesn't refer to the output from running this // thunk, but rather to the "output" of a hypothetical forward convolution // that corresponds to this input+filter+output triple. That is, the result // generated by this thunk is "output" for forward convs, "input" for // backward-input convs, and "filter" for backward-filter convs. - // - // Semantics of null hlo_instruction argument are as in Thunk. - ConvolutionThunk(CudnnConvKind convolution_kind, - const BufferAllocation::Slice& input_buffer, - const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, - const BufferAllocation::Slice& tuple_result_buffer, - const BufferAllocation::Slice& scratch_buffer, - const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dim_nums, - int64 feature_group_count, int64 algorithm, - bool tensor_ops_enabled, const HloInstruction* hlo); + ConvolutionThunk(const HloCustomCallInstruction* cudnn_call, + BufferAllocation::Slice input_slice, + BufferAllocation::Slice filter_slice, + BufferAllocation::Slice output_slice, + BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice tuple_result_slice) + : Thunk(Kind::kConvolution, cudnn_call), + cudnn_call_(cudnn_call), + input_buffer_(std::move(input_slice)), + filter_buffer_(std::move(filter_slice)), + output_buffer_(std::move(output_slice)), + scratch_buffer_(std::move(scratch_slice)), + tuple_result_buffer_(std::move(tuple_result_slice)) {} ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; @@ -72,23 +70,12 @@ class ConvolutionThunk : public Thunk { HloExecutionProfiler* profiler) override; private: - const CudnnConvKind convolution_kind_; - - const BufferAllocation::Slice input_buffer_; - const BufferAllocation::Slice filter_buffer_; - const BufferAllocation::Slice output_buffer_; - const BufferAllocation::Slice 
tuple_result_buffer_; - const BufferAllocation::Slice scratch_buffer_; - - const Shape input_shape_; - const Shape filter_shape_; - const Shape output_shape_; - - const Window window_; - const ConvolutionDimensionNumbers dim_nums_; - int64 feature_group_count_; - int64 algorithm_; - bool tensor_ops_enabled_; + const HloCustomCallInstruction* cudnn_call_; + BufferAllocation::Slice input_buffer_; + BufferAllocation::Slice filter_buffer_; + BufferAllocation::Slice output_buffer_; + BufferAllocation::Slice scratch_buffer_; + BufferAllocation::Slice tuple_result_buffer_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 5c2555148ae5de4a15e5a5f003b4783c64a20e9c..c607aea1a8c74057444467cecd7087f967bc7ee4 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/mutex.h" @@ -176,10 +177,14 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // caching would speed up compilation a lot. 
StatusOr> CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - HloInstruction* instr) { + const HloCustomCallInstruction* instr) { + CudnnConvParams params; + TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, ¶ms)); + + const Shape& input_shape = *params.input_shape; + const Shape& filter_shape = *params.filter_shape; + const Shape& output_shape = *params.output_shape; + CHECK_EQ(input_shape.element_type(), filter_shape.element_type()); CHECK_EQ(input_shape.element_type(), output_shape.element_type()); // TODO(timshen): for now only check fp16. It can be expanded to other types, @@ -220,13 +225,13 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // use a ScratchAllocator for this instead of calling allocator_ directly so // that our allocations don't leak. ScratchAllocator input_output_allocator(device_ordinal, allocator); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf, + TF_ASSIGN_OR_RETURN(params.input_buf, input_output_allocator.AllocateBytes( &stream, ShapeUtil::ByteSizeOf(input_shape))); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf, + TF_ASSIGN_OR_RETURN(params.filter_buf, input_output_allocator.AllocateBytes( &stream, ShapeUtil::ByteSizeOf(filter_shape))); - TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf, + TF_ASSIGN_OR_RETURN(params.output_buf, input_output_allocator.AllocateBytes( &stream, ShapeUtil::ByteSizeOf(output_shape))); @@ -253,32 +258,32 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( static_cast(buffer.opaque()) + aligned_size, left_over_bytes); stream.ThenMemcpy(&left_over, halfs, left_over_bytes); }; - initialize_f16(input_buf); - initialize_f16(filter_buf); - initialize_f16(output_buf); + initialize_f16(params.input_buf); + initialize_f16(params.filter_buf); + initialize_f16(params.output_buf); } else { // Although we don't have 
evidence this matters, zero out the buffers before // autotuning. It's conceivable that using uninitialized memory as the // inputs might affect performance if e.g. the inputs contain denormals, and // this is easy enough. - stream.ThenMemZero(&input_buf, input_buf.size()) - .ThenMemZero(&filter_buf, filter_buf.size()) - .ThenMemZero(&output_buf, output_buf.size()); + stream.ThenMemZero(¶ms.input_buf, params.input_buf.size()) + .ThenMemZero(¶ms.filter_buf, params.filter_buf.size()) + .ThenMemZero(¶ms.output_buf, params.output_buf.size()); } DeviceMemoryBase* result_buf = [&] { - switch (kind) { + switch (params.kind) { case CudnnConvKind::kBackwardFilter: - return &filter_buf; + return ¶ms.filter_buf; case CudnnConvKind::kBackwardInput: - return &input_buf; + return ¶ms.input_buf; case CudnnConvKind::kForward: - return &output_buf; + return ¶ms.output_buf; } }(); const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( - input_shape, output_shape, dnums, stream_exec_); + input_shape, output_shape, *params.dnums, stream_exec_); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; @@ -288,18 +293,16 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // this algorithm considered correct, though. 
optional first_algorithm; for (const AlgorithmDesc& alg : - GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { + GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - bool launch_ok = - RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, input_buf, - filter_buf, output_buf, &scratch_allocator, window, dnums, - feature_group_count, AlgorithmConfig(alg), &stream, &profile_result) - .ok(); + params.algorithm = AlgorithmConfig(alg); + bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream, + &profile_result) + .ok(); if (launch_ok && profile_result.is_valid()) { const bool crash_on_checking_failure = @@ -374,34 +377,8 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( HloInstruction* instr) { CHECK(IsCustomCallToDnnConvolution(*instr)); - const auto& call_target = instr->custom_call_target(); - const auto& lhs_shape = instr->operand(0)->shape(); - const auto& rhs_shape = instr->operand(1)->shape(); - const auto& conv_result_shape = instr->shape().tuple_shapes(0); - StatusOr> alg_scratch_and_tc; - if (call_target == kCudnnConvForwardCallTarget) { - alg_scratch_and_tc = - PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/conv_result_shape, instr->window(), - instr->convolution_dimension_numbers(), - instr->feature_group_count(), instr); - } else if (call_target == kCudnnConvBackwardInputCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), - instr->convolution_dimension_numbers(), instr->feature_group_count(), - instr); - } else if (call_target == kCudnnConvBackwardFilterCallTarget) { - 
alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, - instr->window(), instr->convolution_dimension_numbers(), - instr->feature_group_count(), instr); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instr->ToString(); - } + StatusOr> alg_scratch_and_tc = + PickBestAlgorithm(Cast(instr)); if (!alg_scratch_and_tc.ok()) { LOG(ERROR) << alg_scratch_and_tc.status(); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index 0cb01161b023b900c8c4b1386b679fe2bd5db802..f79b113f8fac0190adef9a8d68d1617710b1402c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -49,10 +50,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); StatusOr> PickBestAlgorithm( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - HloInstruction* instr); + const HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null 
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 9bf721ecd2ad938e71f88a6fc65cd2d3bd25161e..228379a2488a8564564e8b5e35a863553f4bbac2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include #include #include @@ -59,8 +60,6 @@ std::tuple MatchBackwardFilter( HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); - // TODO(b/31709653): Figure out if we can use grouped convolutions also on - // backward filter. if (conv->feature_group_count() > 1) { return no_match_result; } @@ -218,13 +217,16 @@ std::tuple MatchBackwardFilter( // Try to match a backward input pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple MatchBackwardInput( - HloInstruction* conv) { +std::tuple +MatchBackwardInput(HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); - // TODO(b/31709653): Figure out if we can use grouped convolutions also on - // backward input. + // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also + // for the backward input convolution, but at least for now with version 7.1.4 + // it is slower. This needs to be re-evaluated for future cuDNN versions. + // Note that we already have the necessary code down below, the only thing to + // enable it is to remove the following early return. if (conv->feature_group_count() > 1) { return no_match_result; } @@ -232,51 +234,38 @@ std::tuple MatchBackwardInput( // Match instruction pattern. 
CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); HloInstruction* reverse_filter = conv->mutable_operand(1); - - // Match the reverse of the filter. ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); - const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions(); - if (reverse_filter->opcode() == HloOpcode::kReverse) { - if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() || - !std::is_permutation(kernel_spatial_dims.begin(), - kernel_spatial_dims.end(), - reverse_filter->dimensions().begin())) { - VLOG(1) - << "Backward input convolution should reverse all kernel dimensions."; - return no_match_result; - } - } else if (reverse_filter->IsConstant()) { - // If the filter is a constant, we're willing to pattern-match to a - // backwards-input conv, on the theory that - // - // a) reversing a constant is free, and - // b) even if the user specified this filter as reverse(constant), we would - // long ago have constant-folded away the reverse. - // - // If the constant has any other uses, reversing it isn't entirely free, - // since we'd now have two constants to keep in memory. But hopefully it's - // free enough. - // - // TODO(jlebar): Should we do this even if the filter is not a constant? - // Reversing a non-constant filter is probably cheaper than padding the - // input! - - // Nothing to do, just fall through. - } else { - // Possibly 1x1 filter. - for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) { - if (conv->window().dimensions(i).size() != 1) { - VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: " - << reverse_filter->ToString(); - return no_match_result; - } - } - if (!window_util::HasBaseDilation(conv->window())) { - VLOG(1) << conv->ToString() - << " is a regular forward convolution. 
No need " - "to fold it to a backward input convolution."; - return no_match_result; - } + + // We pattern-match to a backwards input conv if: + // + // - all spatial dims of the filter are reversed + // + // OR + // + // - filter is 1x1 or a constant AND + // - conv has base dilation (otherwise this is just a regular forward conv). + // + // The final criterion above is just for canonicalization; cudnn seems to run + // just as fast if we canonicalize 1x1/constant filters without base dilation + // to forward or backward convs. We canonicalize to forward conv because (a) + // it's more natural (constant filters usually show up when doing inference, + // and having backwards convolutions in inference graphs would be weird), and + // (b) cudnn has special fusions for forward conv plus bias and activation, + // and we want to pattern-match to that after running this pass. + bool is_reversed_filter = + reverse_filter->opcode() == HloOpcode::kReverse && + absl::c_is_permutation(dnums.kernel_spatial_dimensions(), + reverse_filter->dimensions()); + bool is_1x1_filter = + absl::c_all_of(conv->window().dimensions(), + [](const WindowDimension& d) { return d.size() == 1; }); + if (!is_reversed_filter && + !(window_util::HasBaseDilation(conv->window()) && + (reverse_filter->IsConstant() || is_1x1_filter))) { + VLOG(1) << "Can't match to backwards convolution. Either filter is not " + "kReverse, or it's not a base-dilated conv with a 1x1 or " + "constant filter."; + return no_match_result; } // Match padding and dilation of the forward convolution. @@ -401,26 +390,64 @@ std::tuple MatchBackwardInput( } } - // OK, it's a match! Canonicalize the conv's filter so that it's a reverse. - // This simplifies things for our caller, and algebraic-simplifier will later - // remove any unnecessary reverses. - if (reverse_filter->opcode() != HloOpcode::kReverse) { + // OK, it's a match! Switch the input feature dimension with the output + // feature dimension. 
This is the way cuDNN expects it to be. + dnums.set_kernel_input_feature_dimension( + conv->convolution_dimension_numbers().kernel_output_feature_dimension()); + dnums.set_kernel_output_feature_dimension( + conv->convolution_dimension_numbers().kernel_input_feature_dimension()); + + // If we matched against a constant, we need to add a reverse op that can be + // subsumed by the cuDNN call. algebraic-simplifier will later remove any + // unnecessary reverses. + if (reverse_filter->opcode() != HloOpcode::kReverse && + reverse_filter->IsConstant()) { // Create a double-reverse, which is a nop. HloComputation* c = conv->parent(); - reverse_filter = c->AddInstruction( - HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, - AsInt64Slice(kernel_spatial_dims))); - reverse_filter = c->AddInstruction( - HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, - AsInt64Slice(kernel_spatial_dims))); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); + reverse_filter = c->AddInstruction(HloInstruction::CreateReverse( + reverse_filter->shape(), reverse_filter, + AsInt64Slice(dnums.kernel_spatial_dimensions()))); TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); } - dnums.set_kernel_input_feature_dimension( - conv->convolution_dimension_numbers().kernel_output_feature_dimension()); - dnums.set_kernel_output_feature_dimension( - conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - return std::make_tuple(true, new_window, dnums); + // Calculate the 'rhs' that goes into the backward input convolution. + HloInstruction* rhs = reverse_filter; + // One reverse is subsumed by the cuDNN call. 
+ if (rhs->opcode() == HloOpcode::kReverse) { + rhs = rhs->mutable_operand(0); + } + if (conv->feature_group_count() == 1) { + return std::make_tuple(true, new_window, dnums, rhs); + } + + // Handle grouped convolutions. Because we swapped the input feature dimension + // with the output feature dimension, we need to also reshape the kernel so + // that the 'feature_group_count' parameter still makes sense. The + // 'feature_group_count' parameter essentially specifies how often the + // 'kernel_input_feature_dimension' is repeated. So when we swap these + // dimensions, we need to divide the new 'kernel_input_feature_dimension' by + // 'feature_group_count' and multiply the new + // 'kernel_output_feature_dimension' by 'feature_group_count'. + Shape new_shape = rhs->shape(); + int64 input_feature_dimension = dnums.kernel_input_feature_dimension(); + int64 output_feature_dimension = dnums.kernel_output_feature_dimension(); + + // In the backward convolution case, the spatial dimensions become the + // feature dimensions, and we are guaranteed that the spatial dimensions are + // adjacent. + CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL); + int64 input_features = new_shape.dimensions(input_feature_dimension); + int64 output_features = new_shape.dimensions(output_feature_dimension); + new_shape.set_dimensions(input_feature_dimension, + input_features / conv->feature_group_count()); + new_shape.set_dimensions(output_feature_dimension, + output_features * conv->feature_group_count()); + HloComputation* c = conv->parent(); + rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs)); + return std::make_tuple(true, new_window, dnums, rhs); } // Tries to rewrite a single convolution into a call to cudnn. 
@@ -431,6 +458,7 @@ StatusOr RunOnInstruction(HloInstruction* conv) { bool match; Window window; ConvolutionDimensionNumbers dnums; + HloInstruction* rhs; std::tie(match, window, dnums) = MatchBackwardFilter(conv); if (match) { @@ -439,13 +467,8 @@ StatusOr RunOnInstruction(HloInstruction* conv) { window, dnums, conv->feature_group_count()); } - std::tie(match, window, dnums) = MatchBackwardInput(conv); + std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); if (match) { - // Backward input conv subsumes the conv plus the reverse in operand 1. - HloInstruction* reverse = conv->mutable_operand(1); - CHECK_EQ(reverse->opcode(), HloOpcode::kReverse); - HloInstruction* rhs = reverse->mutable_operand(0); - return CreateCudnnConvBackwardInput(conv->shape(), conv->mutable_operand(0), rhs, window, dnums, conv->feature_group_count()); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index 46c23db4652cccb06c9ca2a199a46ae04b332286..d237f8930b74d460ad3d4602670a5afb19b496a2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -107,12 +107,12 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { conv_window.mutable_dimensions(1)->set_size(2); conv_window.mutable_dimensions(1)->set_window_dilation(2); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(activations->shape(), - gradients->shape(), conv_window, - tf_default_dnums_for_backward_filter_) + ShapeInference::InferConvolveShape( + activations->shape(), gradients->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, conv_window, - tf_default_dnums_for_backward_filter_)); + activations, gradients, /*feature_group_count=*/1, conv_window, + 
tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -135,12 +135,12 @@ TEST_F(CudnnConvolutionRewriterTest, Window conv_window = default_conv_window_; conv_window.mutable_dimensions(1)->set_size(3); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(activations->shape(), - gradients->shape(), conv_window, - tf_default_dnums_for_backward_filter_) + ShapeInference::InferConvolveShape( + activations->shape(), gradients->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, conv_window, - tf_default_dnums_for_backward_filter_)); + activations, gradients, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -170,7 +170,8 @@ TEST_F(CudnnConvolutionRewriterTest, } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -200,7 +201,8 @@ TEST_F(CudnnConvolutionRewriterTest, } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -228,7 +230,8 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, 
gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -272,13 +275,14 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output, - /*rhs=*/reverse_kernel, conv_window, conv_dnums)); + /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window, + conv_dnums, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, conv_dnums) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, conv_window, conv_dnums) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -319,11 +323,11 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - conv_window, + /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, conv_window, - tf_default_dnums_for_backward_input_)); + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -350,12 +354,13 @@ TEST_F(CudnnConvolutionRewriterTest, 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - 
default_conv_window_, - tf_default_dnums_for_backward_input_) + ShapeInference::InferConvolveShape( + output->shape(), kernel->shape(), /*feature_group_count=*/1, + default_conv_window_, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, default_conv_window_, - tf_default_dnums_for_backward_input_)); + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, + default_conv_window_, tf_default_dnums_for_backward_input_, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -402,13 +407,15 @@ TEST_F(CudnnConvolutionRewriterTest, } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -449,13 +456,15 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -502,13 +511,15 @@ TEST_F(CudnnConvolutionRewriterTest, forward_conv_col_dim->set_base_dilation(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); const HloComputation* entry_computation = @@ -554,13 +565,15 @@ TEST_F(CudnnConvolutionRewriterTest, forward_conv_col_dim->set_padding_high(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. 
CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -577,7 +590,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) { Array4D constant_arr(4, 4, 2, 2); constant_arr.FillIota(0); string constant_str = - LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString(); + LiteralUtil::CreateR4FromArray4D(constant_arr).ToString(); ParseAndVerifyModule(absl::StrFormat(R"( HloModule test diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 05125e9d1fb3cd03cb72b7854fc28c767b49fd64..2a86ac265e4d6a6502162ac33b04b0ee362ce49e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -72,14 +72,22 @@ class ScratchBufAllocator : public se::ScratchAllocator { }; template -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, DeviceMemory input_buf, - DeviceMemory filter_buf, DeviceMemory output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - AlgorithmConfig algorithm, Stream* stream, - ProfileResult* profile_result /*= nullptr*/) { +Status RunCudnnConvolutionImpl(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + CudnnConvKind kind = params.kind; + const Shape& input_shape = *params.input_shape; + const Shape& 
filter_shape = *params.filter_shape; + const Shape& output_shape = *params.output_shape; + DeviceMemory input_buf(params.input_buf); + DeviceMemory filter_buf(params.filter_buf); + DeviceMemory output_buf(params.output_buf); + const Window& window = *params.window; + const ConvolutionDimensionNumbers& dnums = *params.dnums; + int64 feature_group_count = params.feature_group_count; + AlgorithmConfig algorithm = params.algorithm; + VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); VLOG(3) << "tensor_ops_enabled: " << algorithm.algorithm().tensor_ops_enabled(); @@ -219,54 +227,31 @@ string CudnnConvKindToString(CudnnConvKind kind) { } } -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { +Status RunCudnnConvolution(CudnnConvParams params, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, input_buf, filter_buf, - output_buf, &scratch_allocator, window, dnums, feature_group_count, - algorithm, stream, profile_result); + return RunCudnnConvolution(params, &scratch_allocator, stream, + profile_result); } -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - 
se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result) { - PrimitiveType output_primitive_type = output_shape.element_type(); +Status RunCudnnConvolution(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result) { + PrimitiveType output_primitive_type = params.output_shape->element_type(); switch (output_primitive_type) { case F16: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), scratch_allocator, window, - dnums, feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl(params, scratch_allocator, + stream, profile_result); case F32: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), scratch_allocator, window, dnums, - feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl(params, scratch_allocator, stream, + profile_result); case F64: - return RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf), - se::DeviceMemory(filter_buf), - se::DeviceMemory(output_buf), scratch_allocator, window, - dnums, feature_group_count, algorithm, stream, profile_result); + return RunCudnnConvolutionImpl(params, scratch_allocator, stream, + profile_result); default: - LOG(FATAL) << ShapeUtil::HumanString(output_shape); + LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape); } } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index a1b4fc71d0cac3e5ea067ca7941b07cbade8d7cc..381aa37a1b1405e00d62adf9855e9229482f5b86 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ 
b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -47,6 +47,20 @@ enum class CudnnConvKind { kBackwardFilter, // input + output => filter }; +struct CudnnConvParams { + CudnnConvKind kind; + const Shape* input_shape; + const Shape* filter_shape; + const Shape* output_shape; + se::DeviceMemoryBase input_buf; + se::DeviceMemoryBase filter_buf; + se::DeviceMemoryBase output_buf; + const Window* window; + const ConvolutionDimensionNumbers* dnums; + int64 feature_group_count; + se::dnn::AlgorithmConfig algorithm; +}; + // Converts a CudnnConvKind value to a string. string CudnnConvKindToString(CudnnConvKind kind); @@ -55,10 +69,9 @@ string CudnnConvKindToString(CudnnConvKind kind); // Note that depending on the value of CudnnConvKind, the result of this call // may be written into input_buf, filter_buf, or output_buf! // -// At the moment we only support cudnn convolutions over float and half, and -// convolution with half data type is implemented with cudnn PSEUDO_HALF -// configuration, that is, the input values are half and the internal -// computation type is float. +// At the moment convolution with half data type is implemented with cudnn +// PSEUDO_HALF configuration, that is, the input values are half and the +// internal computation type is float. // // We provide one overload which takes a scratch buffer, and another which takes // an allocator which is responsible for allocating the scratch space. In @@ -70,23 +83,14 @@ string CudnnConvKindToString(CudnnConvKind kind); // allocator and take note of how much memory is used. The next time you call // the same conv, you can provide an explicitly preallocated scratch buffer of // that size, if you like. 
-Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); +Status RunCudnnConvolution(CudnnConvParams params, + se::DeviceMemoryBase scratch_buf, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); -Status RunCudnnConvolution( - CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, se::DeviceMemoryBase input_buf, - se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, - se::ScratchAllocator* scratch_allocator, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count, - se::dnn::AlgorithmConfig algorithm, se::Stream* stream, - se::dnn::ProfileResult* profile_result = nullptr); +Status RunCudnnConvolution(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 743035a84eaeb41fafb336844a1a7a07b82af4db..02a0d028c118aba23996f9b97d05443bb4a00c88 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -21,8 +21,9 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -198,11 +199,12 @@ StatusOr> GpuHloSchedule::Build( // All kernels are launched on a single stream, so there's no loss of // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( - schedule->thunk_launch_order_, - ScheduleOneComputation( + HloInstructionSequence sequence, + ScheduleComputation( *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); + schedule->thunk_launch_order_ = sequence.instructions(); } else { // BFS tends to increase concurrency, but also increases memory usage. BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h index 30a0e7cecd202e83898d34e00b5b49684d1b1b68..07a7fc67aa555845c3de57e574ab582403ec0490 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h @@ -33,7 +33,9 @@ namespace gpu { // launches, because thunks may be scheduled onto concurrent streams. This // schedule is used by BufferAssigner to determine buffer liveness (i.e. to // minimize allocations), and also by ThunkSchedule to determine the thunk -// launch order. +// launch order. This class differs from xla::HloSchedule in that HloSchedule +// represents a total order of all instructions in the module for backends which +// execute HLO instructions strictly sequentially. 
class GpuHloSchedule { public: // Constructs an GpuHloSchedule for the given module, based on the given diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc index 0922e44a126eadab17d60d9ece53aae8d8f1c218..b857fa775a76ec999b505a2a64332cc0c54cf00b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -24,13 +24,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class GpuHloScheduleTest : public HloTestBase { +class GpuHloScheduleTest : public HloVerifiedTestBase { protected: using HloVec = std::vector; @@ -73,10 +74,10 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) { /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -201,12 +202,12 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) { /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( 
/*parameter_number=*/1, f32_2x2_, /*name=*/"y")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); - HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x)); + HloInstruction* add = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(add)); @@ -269,23 +270,23 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) { i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); - HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); - HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); - HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); - HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); - HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); - HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); - HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); - HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); + CreateCanonicalDot(f32_2x2_, params[2], params[3])); + HloInstruction* d10 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00)); + 
HloInstruction* d11 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4])); + HloInstruction* d20 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10)); + HloInstruction* d21 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11)); + HloInstruction* d22 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5])); + HloInstruction* d30 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21)); + HloInstruction* d31 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22)); + HloInstruction* d40 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc index 0a4089df4c954cafcbe241189ee79a0995683513..27a4d0b601f3807fe6b94dd6171a44f292921ede 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -25,7 +25,7 @@ namespace { using ::testing::HasSubstr; -class GpuHloSupportCheckerTest : public HloTestBase { +class GpuHloSupportCheckerTest : public HloVerifiedTestBase { protected: GpuHloSupportChecker& checker() { return checker_; } @@ -45,7 +45,7 @@ TEST_F(GpuHloSupportCheckerTest, Add) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK(checker().Run(module.get()).status()); + TF_ASSERT_OK(checker().Run(module).status()); } TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { @@ -60,7 +60,7 @@ TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Status status = checker().Run(module.get()).status(); + Status status = checker().Run(module).status(); ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); EXPECT_THAT(status.error_message(), HasSubstr("GPU backend does not support")); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index bca775c4750dd3aa679846d54e29a9d277adad79..96bfe0c12eb9cd6ef25804d6b34767471616f7e4 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" namespace op = xla::testing::opcode_matchers; @@ -111,8 +112,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( - ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); + auto dot1 = builder.AddInstruction( + CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -128,8 +129,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( - ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); + auto dot1 = builder.AddInstruction( + CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 20d523abe0552f0bc22c365007c096666ec888f6..22f43bc08bd08abd735f88f32f28c528499cf3d2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -287,5 +288,42 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, value->getType()); } +Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call, + CudnnConvParams* params) { + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + custom_call->backend_config()); + const auto& target = custom_call->custom_call_target(); + const auto& lhs_shape = custom_call->operand(0)->shape(); + const auto& rhs_shape = custom_call->operand(1)->shape(); + const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); + + params->window = &custom_call->window(); + params->dnums = &custom_call->convolution_dimension_numbers(); + params->feature_group_count = custom_call->feature_group_count(); + params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( + backend_config.algorithm(), backend_config.tensor_ops_enabled())); + + if (target == kCudnnConvForwardCallTarget) { + params->kind = CudnnConvKind::kForward; + params->input_shape = &lhs_shape; + params->filter_shape = &rhs_shape; + params->output_shape = &conv_result_shape; + } else if (target == kCudnnConvBackwardInputCallTarget) { + params->kind = CudnnConvKind::kBackwardInput; + params->input_shape = &conv_result_shape; + params->filter_shape = &rhs_shape; + params->output_shape = &lhs_shape; + } else if (target == kCudnnConvBackwardFilterCallTarget) { + params->kind = CudnnConvKind::kBackwardFilter; + params->input_shape = &lhs_shape; + params->filter_shape = &conv_result_shape; + params->output_shape = &rhs_shape; + } else { + LOG(FATAL) << "Unexpected custom call target: " + << custom_call->custom_call_target(); + } + return 
Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 59c65fc2686cd4a00a3770ebaedf637e8f556828..09c455cc1e137b4a9836a58d5b70e62a4bfa120a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -20,7 +20,9 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" // TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they // don't belong in "ir_emission_utils". @@ -148,6 +150,11 @@ llvm::Value* EmitPrintf(absl::string_view fmt, llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); +// Populates params using conv, which must be a custom-call to a cudnn +// convolution. Does not modify any buffers in the params. 
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call, + CudnnConvParams* params); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index ffca5d6549a8316a7c7b7946d9943f091c133d1b..b7c37bcf3ca910f10d18339dfe7f1d29f2a55c9e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -764,5 +764,20 @@ StatusOr IrEmitter::ComputeNestedElement( return Load(return_buffer); } +std::vector IrEmitter::ConstructIrArrayForOutputs( + const HloInstruction& hlo) { + std::vector output_arrays; + if (ShapeUtil::IsTuple(hlo.shape())) { + int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); + output_arrays.reserve(num_outputs); + for (int64 i = 0; i < num_outputs; ++i) { + output_arrays.push_back(GetIrArray(hlo, hlo, {i})); + } + } else { + output_arrays.push_back(GetIrArray(hlo, hlo)); + } + return output_arrays; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 579268f07185fd2d8ec74750f1bf833101149437..880520148005838cc25a5be9e26c8bc9028a70ce 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -124,6 +124,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, llvm::Value* GetBasePointer(const HloInstruction& inst) const { return bindings_.GetBasePointer(inst); } + + // Generates the IrArray for each output of an hlo instruction and returns + // a vector containing such IrArrays. + std::vector ConstructIrArrayForOutputs( + const HloInstruction& hlo); + // A convenient helper for calling BufferAssignment::GetUniqueSlice. 
BufferAllocation::Slice GetAllocationSlice( const HloInstruction& hlo, const ShapeIndex& index = {}) const { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 5c827e5f9cf3e1c04af444dae338a2ec411ce372..66c65f69758e5a2f4420935279835eaf086fea45 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -119,21 +119,11 @@ Status IrEmitterNested::EmitTargetElementLoop( // For MOF we give the loop emitter an array for every output it should // generate. if (hlo.IsMultiOutputFusion()) { - const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape()); - std::vector target_arrays; - target_arrays.reserve(num_elems); - for (int64 i = 0; i != num_elems; ++i) { - target_arrays.push_back(GetIrArray(hlo, hlo, {i})); - } + std::vector target_arrays = + ConstructIrArrayForOutputs(hlo); TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop()); - - std::vector tuple_operand_ptrs; - tuple_operand_ptrs.reserve(num_elems); - for (const llvm_ir::IrArray& array : target_arrays) { - tuple_operand_ptrs.push_back(array.GetBasePointer()); - } - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_); return Status::OK(); } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 389a98facb9b553a91342bb7fc42642179aaf698..b669881026276eefe2ca6cbea74d79604dd13066 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -61,6 +61,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -464,67 +465,35 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { if (IsCustomCallToDnnConvolution(*custom_call)) { const auto& assn = ir_emitter_context_->buffer_assignment(); - const auto& lhs_shape = custom_call->operand(0)->shape(); - const auto& rhs_shape = custom_call->operand(1)->shape(); - const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); auto tuple_result_slice = GetAllocationSlice(*custom_call); auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - custom_call->backend_config()); const auto& target = custom_call->custom_call_target(); - std::unique_ptr thunk; + BufferAllocation::Slice input_slice, filter_slice, output_slice; + if (target == kCudnnConvForwardCallTarget) { - thunk = absl::make_unique( - CudnnConvKind::kForward, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/conv_result_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/conv_result_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = lhs_slice; 
+ filter_slice = rhs_slice; + output_slice = conv_result_slice; } else if (target == kCudnnConvBackwardInputCallTarget) { - thunk = absl::make_unique( - CudnnConvKind::kBackwardInput, - /*input_buffer=*/conv_result_slice, - /*filter_buffer=*/rhs_slice, - /*output_buffer=*/lhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/conv_result_shape, - /*filter_shape=*/rhs_shape, - /*output_shape=*/lhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = conv_result_slice; + filter_slice = rhs_slice; + output_slice = lhs_slice; } else if (target == kCudnnConvBackwardFilterCallTarget) { - thunk = absl::make_unique( - CudnnConvKind::kBackwardFilter, - /*input_buffer=*/lhs_slice, - /*filter_buffer=*/conv_result_slice, - /*output_buffer=*/rhs_slice, - /*tuple_result_buffer=*/tuple_result_slice, - /*scratch_buffer=*/scratch_slice, - /*input_shape=*/lhs_shape, - /*filter_shape=*/conv_result_shape, - /*output_shape=*/rhs_shape, // - custom_call->window(), custom_call->convolution_dimension_numbers(), - custom_call->feature_group_count(), backend_config.algorithm(), - backend_config.tensor_ops_enabled(), custom_call); + input_slice = lhs_slice; + filter_slice = conv_result_slice; + output_slice = rhs_slice; } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); } - thunk_sequence_->emplace_back(std::move(thunk)); + thunk_sequence_->emplace_back(absl::make_unique( + Cast(custom_call), input_slice, filter_slice, + output_slice, scratch_slice, tuple_result_slice)); return Status::OK(); } @@ -2521,15 +2490,15 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( } StatusOr> IrEmitterUnnested::BuildInitializerThunk( - const HloInstruction* hlo, const ShapeIndex& index) { + HloInstruction* hlo, const ShapeIndex& index) { bool fused 
= HloOpcode::kFusion == hlo->opcode(); - const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - const HloInstruction* init_value_operand = [&] { + HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; + HloInstruction* init_value_operand = [&] { switch (inst->opcode()) { case HloOpcode::kSelectAndScatter: - return inst->operand(2); + return inst->mutable_operand(2); case HloOpcode::kReduce: - return inst->operand(1); + return inst->mutable_operand(1); case HloOpcode::kTuple: CHECK(hlo->IsMultiOutputFusion()) << ": " << hlo->ToString() << " is not a multi-output fusion."; @@ -2537,7 +2506,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( << ": Found '" << inst->operand(index.back())->opcode() << "' in " << inst->ToString() << " but expected 'reduce'."; // For multi-output fusion look through the tuple. - return inst->operand(index.back())->operand(1); + return inst->mutable_operand(index.back())->mutable_operand(1); default: LOG(FATAL) << "Opcode " << inst->opcode() << " should not need an initializer."; @@ -2609,28 +2578,35 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( ir_emitter_context_->device_description()); UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); - // If the init_value was fused into this reduce we have to generate it first. - if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { - CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); - const Literal& literal = init_value_operand->literal(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, module_); + if (fused) { + // If init_value was fused into this reduce we have to generate it first. 
+ std::vector parameter_arrays; + for (HloInstruction* operand : hlo->operands()) { + parameter_arrays.push_back(GetIrArray(*operand, *hlo)); + } + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, + ir_emitter_context_->llvm_module(), + &b_, GetNestedComputer()); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *module_, initializer->getType(), - /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, - /*Name=*/""); - global_for_const->setAlignment(kConstantBufferAlignBytes); - bindings_.BindHloToIrValue(*init_value_operand, global_for_const); + FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter)); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand), + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &b_) + .EmitLoop(IrName(hlo))); + } else { + // In the unfused case the element is already there, just read from it. + TF_RETURN_IF_ERROR(ParallelLoopEmitter( + [=](const IrArray::Index& index) { + return GetIrArray(*init_value, *hlo) + .EmitReadArrayElement(index, &b_); + }, + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &b_) + .EmitLoop(IrName(hlo))); } - TF_RETURN_IF_ERROR(ParallelLoopEmitter( - [=](const IrArray::Index& index) { - return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &b_); - }, - GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_) - .EmitLoop(IrName(hlo))); // Clean up state left behind by emitting the loop above. (This is normally // done in IrEmitterUnnested::Postprocess().) @@ -2819,10 +2795,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( } // For multioutput fusion, we need to emit each operand and the root. 
- std::vector output_arrays; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { - output_arrays.push_back(GetIrArray(hlo, hlo, {i})); - } + std::vector output_arrays = ConstructIrArrayForOutputs(hlo); TF_RETURN_IF_ERROR( ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &b_, unroll_factor) @@ -2830,12 +2803,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( GetIndexTypeForKernel( &hlo, launch_dimensions.launch_bound(), &b_))); - std::vector tuple_operand_ptrs; - for (int64 i = 0; i < output_arrays.size(); ++i) { - tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); - } b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_); + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_); + return Status::OK(); } @@ -2847,29 +2817,14 @@ Status IrEmitterUnnested::EmitTargetElementLoop( static_cast(LastThunk())); } -int IrEmitterUnnested::ConstructIrArrayForOutputs( - const HloInstruction& hlo, std::vector* output_arrays) { - int64 num_outputs = 1; - if (hlo.IsMultiOutputFusion()) { - num_outputs = ShapeUtil::TupleElementCount(hlo.shape()); - output_arrays->reserve(num_outputs); - for (int64 i = 0; i < num_outputs; ++i) { - output_arrays->push_back(GetIrArray(hlo, hlo, {i})); - } - } else { - output_arrays->push_back(GetIrArray(hlo, hlo)); - } - return num_outputs; -} - -int IrEmitterUnnested::ConstructIrArrayForInputs( - const HloInstruction& hlo, std::vector* param_arrays) { - int64 num_params = hlo.operands().size(); - param_arrays->reserve(num_params); +std::vector IrEmitterUnnested::ConstructIrArrayForInputs( + const HloInstruction& hlo) { + std::vector param_arrays; + param_arrays.reserve(hlo.operands().size()); for (const HloInstruction* param : hlo.operands()) { - param_arrays->push_back(GetIrArray(*param, hlo)); + param_arrays.push_back(GetIrArray(*param, hlo)); } - return num_params; + return 
param_arrays; } int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape( @@ -3050,10 +3005,10 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( constexpr int64 kThreadsPerTile = kTileSize * kNumRows; // Construct IrArrays for the inputs and outputs. - std::vector output_arrays; - int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays); - std::vector param_arrays; - int64 num_params = ConstructIrArrayForInputs(*hlo, ¶m_arrays); + std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); + int64 num_outputs = output_arrays.size(); + std::vector param_arrays = ConstructIrArrayForInputs(*hlo); + int64 num_params = param_arrays.size(); // Allocate shared memory buffers to store the tiled inputs. std::vector param_shmem_buffers(num_params, nullptr); @@ -3251,12 +3206,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( // For multioutput fusion, emit a tuple with all the individual outputs. if (hlo->IsMultiOutputFusion()) { - std::vector tuple_operand_ptrs; - for (int64 i = 0; i < output_arrays.size(); ++i) { - tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); - } - llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_, - module_); + llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_); } return launch_dimensions; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 084462330ed20108a9ec850b4cbc588afe77cc01..bd5db7205155dc6b15ddea069e172bbd8f419996 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -193,14 +193,12 @@ class IrEmitterUnnested : public IrEmitter { LaunchDimensions EmitHlo021Tile(HloInstruction* hlo, absl::Span reduced_output_dims, absl::Span tiled_param_ids); - // Generates the IrArray for each output of hlo and returns the number of - // outputs. 
- int ConstructIrArrayForOutputs(const HloInstruction& hlo, - std::vector* output_arrays); - // Generates the IrArray for each input of hlo and returns the number of - // inputs. - int ConstructIrArrayForInputs(const HloInstruction& hlo, - std::vector* param_arrays); + + // Generates the IrArray for each input of an hlo and returns a vector that + // constains such IrArrays. + std::vector ConstructIrArrayForInputs( + const HloInstruction& hlo); + // For each output of the `hlo` instruction, constructs the reduced shape for // the output with the given `reduced_output_dims` and cast the original // output IrArray element in `output_arrays` to the reduced shape. Returns @@ -244,7 +242,7 @@ class IrEmitterUnnested : public IrEmitter { // Returns a thunk that, given a reduce or select-and-scatter op, initializes // its memory to the appropriate initial value. StatusOr> BuildInitializerThunk( - const HloInstruction* hlo, const ShapeIndex& index = {}); + HloInstruction* hlo, const ShapeIndex& index = {}); // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index c822c94f1b102e02be4a13a35892a2c181702383..8a6e5327e082791ff857a89e840c6a4f045f0edb 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -259,7 +259,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) { // Fusing a reduce into a loop fusion would require changing the fusion kind. // That's not supported yet. 
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[6400]{0} parameter(0) ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) @@ -277,7 +277,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[6400]{0} parameter(0) ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) @@ -301,7 +301,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) @@ -324,7 +324,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) @@ -358,7 +358,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) { TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, 
p0.1) diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index f6325b33680629b7e3d3814b088582a5007de6dc..dfdcf1875dd3f5749bd1fd95ad0eeb8c11955887 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -208,10 +208,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); pipeline.AddPass(); - // CudnnConvolutionRewriter may add instructions of the form - // reverse(constant), which it expects will be simplified by constant - // folding. - pipeline.AddPass(); pipeline.AddPass(); if (IsVoltaOrLater(*stream_exec)) { pipeline.AddPass(); @@ -219,6 +215,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // pairs that TupleSimplifier fixes. pipeline.AddPass(); } + // CudnnConvolutionRewriter, PadInsertion and PadForTensorCores may add + // instructions which can be simplified by constant folding. + pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc index fa84d7722351b68770b876e3880b472eec3233d7..b0061fa6558ac92bffd3dff13e736421a62dc484 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -23,7 +23,6 @@ limitations under the License. namespace xla { namespace gpu { - // We want the input/output feature counts of an f16 conv to be factors of 8, // because without this cudnn can't use tensor cores on the conv. 
static constexpr int64 kDesiredNumFeaturesFactor = 8; @@ -63,8 +62,8 @@ static HloInstruction* PadInstruction(HloInstruction* instr, HloComputation* comp = instr->parent(); const Shape& shape = instr->shape(); - auto* zero = comp->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(shape.element_type()).CloneToUnique())); + auto* zero = comp->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type()))); PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape)); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 9d85d746d84908eaa8d720bc3cccc475d81710f3..2a6415d0b6c973cb72c30b7a803b5f603c1d5e4d 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -68,9 +68,8 @@ HloInstruction* MaybePaddedAndSlicedInput( conv_window.dimensions(i).base_dilation() - 1); } PrimitiveType element_type = input->shape().element_type(); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique(LiteralUtil::Zero(element_type)))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -125,9 +124,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, HloComputation* computation = kernel->parent(); PrimitiveType element_type = kernel->shape().element_type(); - HloInstruction* padding = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique(LiteralUtil::Zero(element_type)))); + HloInstruction* padding = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -236,9 +234,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Create 
a new backward convolution replacing the old one. HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(absl::make_unique( - LiteralUtil::Zero(input->shape().element_type())))); + HloInstruction* padding = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(input->shape().element_type()))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 091aca23e54bf0585b91e7a05c0837d8a0a2b764..c4f43cc9a614283acb376b5f98e4976615b590ad 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -21,13 +21,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" namespace xla { namespace gpu { -class StreamAssignmentTest : public HloTestBase { +class StreamAssignmentTest : public HloVerifiedTestBase { protected: std::unique_ptr CreateNewModule() { HloModuleConfig config; @@ -49,10 +50,10 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - 
HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -68,10 +69,10 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { /*parameter_number=*/0, f32_2x2_, /*name=*/"x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); - HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); - HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); + HloInstruction* dot1 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y)); + HloInstruction* dot2 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); @@ -101,23 +102,23 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); - HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); - HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); - HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); - HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); - HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); - HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); - HloInstruction* d31 = 
builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); - HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); + CreateCanonicalDot(f32_2x2_, params[2], params[3])); + HloInstruction* d10 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00)); + HloInstruction* d11 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4])); + HloInstruction* d20 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10)); + HloInstruction* d21 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11)); + HloInstruction* d22 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5])); + HloInstruction* d30 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21)); + HloInstruction* d31 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22)); + HloInstruction* d40 = + builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index 4550f36fdfc097632fed4956fcd3e42ef8a919c5..780539c164277f14c2bd964024f7c3ca179f4ada 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -38,8 +38,7 @@ class GpuCopyTest : public GpuCodegenTest {}; TEST_F(GpuCopyTest, UseMemcpy) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); builder.AddInstruction(HloInstruction::CreateUnary( diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc 
b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc index 9072b30317d253fd6d50e9d98949cad4eaebfe7b..f8120a5fa00ce38644cd85c54d5ef65701be1eda 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc @@ -53,40 +53,40 @@ class InfeedTest : public ClientLibraryTestBase { }; TEST_F(InfeedTest, SingleInfeedR0Bool) { - TestInfeedRoundTrip(*LiteralUtil::CreateR0(true)); + TestInfeedRoundTrip(LiteralUtil::CreateR0(true)); } TEST_F(InfeedTest, SingleInfeedR1U32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR1({1, 2, 3})); + TestInfeedRoundTrip(LiteralUtil::CreateR1({1, 2, 3})); } TEST_F(InfeedTest, SingleInfeedR2F32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); + TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64)); } TEST_F(InfeedTest, SingleInfeedR3F32) { TestInfeedRoundTrip( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0minor)); - TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout( + TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout( {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, r3_dim0major)); } TEST_F(InfeedTest, SingleInfeedR4S32) { - TestInfeedRoundTrip(*LiteralUtil::CreateR4( + TestInfeedRoundTrip(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, 
-5}}}})); } @@ -95,26 +95,26 @@ TEST_F(InfeedTest, SingleInfeedR4S32) { TEST_F(InfeedTest, LargeInfeed) { Array4D array(80, 100, 8, 128); array.FillIota(1.0f); - TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D(array)); + TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D(array)); } TEST_F(InfeedTest, SingleInfeedTuple) { - TestInfeedRoundTrip( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR0(false).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR0(false)})); } TEST_F(InfeedTest, SingleInfeedEmptyTuple) { - TestInfeedRoundTrip(*LiteralUtil::MakeTuple({})); + TestInfeedRoundTrip(LiteralUtil::MakeTuple({})); } // Tests that a large tuple infeed can be handled. TEST_F(InfeedTest, SingleInfeedLargeTuple) { Array4D array(40, 100, 8, 128); array.FillIota(1.0f); - TestInfeedRoundTrip(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR4FromArray4D(array).get(), - LiteralUtil::CreateR0(5).get()})); + TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR4FromArray4D(array), + LiteralUtil::CreateR0(5)})); } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 40183de96ee363996e6b0b883a78e7a8b5d13ab2..9a61f8ac5a62e38e687a93890eb33481a01d51c8 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -26,9 +26,6 @@ limitations under the License. 
namespace xla { namespace { -using ::testing::Eq; -using ::testing::HasSubstr; - class WhileTransformerTest : public HloTestBase { protected: WhileTransformerTest() diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index a2be89511babc23ebcd5cb40abee2a95d16dc451..ef70b688778df5115e2b5fe572d253a6948d076f 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -112,8 +112,11 @@ std::unique_ptr MakeBigGraph() { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums)); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfig::DEFAULT); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + vshape, clamp, param_v0, dot_dnums, precision_config)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({dot, param_s, clamp})); auto scalar = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 38c3982ebf170d5733d56a05106835d1eaa4f2e1..e0f3a7e0e2869fa854c0229cd06bbdd641d99363 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -29,13 +29,13 @@ using tensorflow::gtl::FlatSet; /*static*/ StatusOr HeapSimulator::MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { + if (schedule.empty()) { return 0; } - const HloModule* module = module_sequence.begin()->first->parent(); + const HloModule* module = schedule.module(); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, 
TuplePointsToAnalysis::Run(module)); @@ -47,14 +47,13 @@ StatusOr HeapSimulator::MinimumMemoryForModule( TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, HeapSimulator::Run(absl::make_unique(), *module, - module_sequence, *points_to_analysis, size_function)); + schedule, *points_to_analysis, size_function)); return result.heap_size; } /*static*/ StatusOr HeapSimulator::MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap* @@ -71,13 +70,13 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options) { - HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); + HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule); const HloComputation* entry_computation = module.entry_computation(); - const std::vector& instruction_sequence = - FindOrDie(module_sequence, entry_computation); + const HloInstructionSequence& instruction_sequence = + schedule.sequence(entry_computation); TF_RETURN_IF_ERROR(heap.RunComputation( *entry_computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -86,13 +85,13 @@ StatusOr HeapSimulator::Run( /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, const tensorflow::gtl::FlatMap* 
memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*module_sequence=*/nullptr, memory_by_computation); + /*schedule=*/nullptr, memory_by_computation); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -102,7 +101,7 @@ StatusOr HeapSimulator::Run( // 'instruction_sequence'. Status HeapSimulator::RunComputation( const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis) { VLOG(3) << "Computation:\n" << computation.ToString(); // The goal here is to minimize memory usage, assuming the given sequential @@ -133,7 +132,8 @@ Status HeapSimulator::RunComputation( // set of instructions that need to be visited contains all users of all // aliases, that is, all users of all instructions that have the buffer // contained in their points-to set. - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const PointsToSet& points_to = points_to_analysis.GetPointsToSet(instruction); const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); @@ -166,7 +166,8 @@ Status HeapSimulator::RunComputation( std::vector dead_buffers_to_free; std::vector operand_buffers_to_free; - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); @@ -285,14 +286,14 @@ Status HeapSimulator::RunComputation( // The order that the sub-computations are simulated does not affect // correctness; since the whole module has been scheduled, we know that the // sub-computations will never be run concurrently. 
- if (module_sequence_ != nullptr) { + if (schedule_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || instruction->opcode() == HloOpcode::kConditional || instruction->opcode() == HloOpcode::kWhile) { for (const HloComputation* called_computation : instruction->called_computations()) { - const std::vector& called_sequence = - FindOrDie(*module_sequence_, called_computation); + const HloInstructionSequence& called_sequence = + schedule_->sequence(called_computation); TF_RETURN_IF_ERROR(RunComputation( *called_computation, called_sequence, points_to_analysis)); } @@ -343,16 +344,16 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence, + const HloSchedule* schedule, const tensorflow::gtl::FlatMap* memory_by_computation) : no_fragmentation_stats_(absl::make_unique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - module_sequence_(module_sequence), + schedule_(schedule), memory_by_computation_(memory_by_computation) { - debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); + debug_trace_.set_whole_module_simulation(schedule_ != nullptr); } HeapSimulator::~HeapSimulator() {} diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index af05bedee72d4878f83765e5a5c5baf61bd71ba2..ffbf947d5ad0cf598f9de9f98f5bbe344f095993 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -88,23 +89,22 @@ class HeapSimulator { // Returns the minimum memory required to compute an HLO module where all // computations have been scheduled (represented by the given - // module_sequence), assuming no fragmentation. + // schedule), assuming no fragmentation. static StatusOr MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function); // Returns the minimum memory required to compute the given computation, // assuming no fragmentation. static StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap* memory_by_computation = nullptr); // Run the heap simulation with the given algorithm, assuming the given - // module_sequence, which must contain a topologically-consistent total + // schedule, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid // if instructions are not run in exactly this sequence. // @@ -112,12 +112,12 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. 
// - static StatusOr Run( - std::unique_ptr algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + static StatusOr Run(std::unique_ptr algorithm, + const HloModule& module, + const HloSchedule& schedule, + const TuplePointsToAnalysis& points_to_analysis, + const BufferValue::SizeFunction& size_fn, + const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' // must contain a topologically-consistent total ordering of all instructions @@ -126,7 +126,7 @@ class HeapSimulator { static StatusOr Run( std::unique_ptr algorithm, const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options = Options(), @@ -134,21 +134,19 @@ class HeapSimulator { memory_by_computation = nullptr); private: - // If 'module_sequence' is non-null, it is used to find kCall and kWhile + // If 'schedule' is non-null, it is used to find kCall and kWhile // sub-computations, and the heap simulation for those sub-computations will // be run recursively. I.e. the simulation is run over the whole module. 
- HeapSimulator( - std::unique_ptr algorithm, - const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr, - const tensorflow::gtl::FlatMap* - memory_by_computation = nullptr); + HeapSimulator(std::unique_ptr algorithm, + const BufferValue::SizeFunction& size_fn, + const Options& options, const HloSchedule* schedule = nullptr, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); ~HeapSimulator(); - Status RunComputation( - const HloComputation& computation, - const std::vector& instruction_sequence, - const TuplePointsToAnalysis& points_to_analysis); + Status RunComputation(const HloComputation& computation, + const HloInstructionSequence& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis); bool IgnoreBuffer(const BufferValue* buffer) const; void Alloc(const BufferValue* buffer, const HloInstruction* instruction); @@ -169,11 +167,11 @@ class HeapSimulator { const std::unique_ptr algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; - // module_sequence_ is set by buffer assignment, and memory_by_computation_ is + // schedule_ is set by buffer assignment, and memory_by_computation_ is // set by hlo scheduling. Then, in RunComputation, we check both in order to // handle subcomputations. It would be good to unify the handling of // subcomputations, but it's not clear how. - const SequentialHloOrdering::HloModuleSequence* module_sequence_; + const HloSchedule* schedule_; const tensorflow::gtl::FlatMap* memory_by_computation_; diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 5f85f145657b67634844c849447ef545a6dea468..957c4a68915934796a315f2443c90e571e942e75 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -29,13 +29,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloTestBase {}; +class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { auto module = CreateNewModule(); @@ -85,13 +86,16 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn) - .ValueOrDie()); + HloSchedule schedule(module); + schedule.set_sequence(cond_computation, + {cond_param, cond_iter, cond_data, cond_lt}); + schedule.set_sequence(body_computation, {body_param}); + schedule.set_sequence(entry_computation, {iter, data, tuple, while_op}); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ( + 56, + HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie()); } const char kAlloc[] = "Alloc"; @@ -149,10 +153,11 @@ class HeapSimulatorTracker { auto zero_size = [](const BufferValue& buffer) { return 0; }; auto algorithm = absl::make_unique( absl::make_unique(&actual_calls_)); - result_ = HeapSimulator::Run( - std::move(algorithm), *module_->entry_computation(), - instruction_sequence, *points_to_analysis_, zero_size) - .ConsumeValueOrDie(); + result_ = + HeapSimulator::Run(std::move(algorithm), 
*module_->entry_computation(), + HloInstructionSequence(instruction_sequence), + *points_to_analysis_, zero_size) + .ConsumeValueOrDie(); } explicit HeapSimulatorTracker(const string& name) { @@ -168,11 +173,12 @@ class HeapSimulatorTracker { TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); // Construct the module sequence grouped by computation. - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(module_.get()); tensorflow::gtl::FlatMap reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { const HloInstruction* instruction = full_module_sequence[i]; - module_sequence[instruction->parent()].push_back(instruction); + schedule.GetOrCreateSequence(instruction->parent()) + .push_back(instruction); reverse_position[instruction] = full_module_sequence.size() - i; } @@ -185,8 +191,8 @@ class HeapSimulatorTracker { }; auto algorithm = absl::make_unique( absl::make_unique(&actual_calls_)); - result_ = HeapSimulator::Run(std::move(algorithm), *module_, - module_sequence, *points_to_analysis_, size_fn) + result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule, + *points_to_analysis_, size_fn) .ConsumeValueOrDie(); } @@ -227,7 +233,7 @@ class HeapSimulatorTracker { HeapSimulator::Result result_; }; -class HeapSimulatorTest : public HloTestBase { +class HeapSimulatorTest : public HloVerifiedTestBase { protected: HeapSimulatorTest() {} ~HeapSimulatorTest() override {} @@ -366,8 +372,8 @@ TEST_F(HeapSimulatorTest, MultiplyDot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); // The buffer for dot is the output, and it cannot be shared with the buffer // for mul, since dot isn't elementwise. 
@@ -402,8 +408,8 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA)); @@ -440,10 +446,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot0 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); - auto dot1 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); + auto dot0 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); + auto dot1 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2))); // The buffer for dot1 is the output. No buffers can be shared. 
The buffer // for mul is freed before the end, since it's no longer used after dot0 @@ -481,10 +487,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot0 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); - auto dot1 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); + auto dot0 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); + auto dot1 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1})); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 58b7af93ebfce74951c0f2d65ab226fc94d62e4b..b19ec126382d143b6ded401f2fad56f950d04bbd 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -172,7 +172,7 @@ message HloInstructionProto { xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; // Precision configuration for the instruction. Has backend-specific meaning. - xla.PrecisionConfigProto precision_config = 51; + xla.PrecisionConfig precision_config = 51; // Collective permute field. repeated SourceTarget source_target_pairs = 52; @@ -199,6 +199,17 @@ message HloComputationProto { int64 root_id = 6; } +// Serialization of an HLO schedule. An HLO schedule contains a total order of +// instructions for each non-fusion computation in the module. +message HloScheduleProto { + message InstructionSequence { + repeated int64 instruction_ids = 1; + } + + // Map from computation id to sequence. + map sequences = 1; +} + // Serialization of HloModule. 
message HloModuleProto { string name = 1; @@ -214,16 +225,9 @@ message HloModuleProto { // The id of this module. int64 id = 5; -} -// Serialization of HloOrdering. -message HloOrderingProto { - // NOTE: currently only sequential orderings are serialized. - message SequentialComputation { - string computation_name = 1; - repeated string instruction_names = 2; - } - repeated SequentialComputation sequential_computations = 1; + // The schedule for this module. + HloScheduleProto schedule = 7; } // Serialization of LogicalBuffer. @@ -305,6 +309,13 @@ message HeapSimulatorTrace { bool whole_module_simulation = 2; } +// An abstraction representing a set of HLO module built to run concurrently +// across different devices. +message HloModuleGroupProto { + string name = 1; + repeated HloModuleProto hlo_modules = 2; +} + // Serialization of BufferAssignment. message BufferAssignmentProto { // Alias represents a source LogicalBuffer, and the buffer location that @@ -322,8 +333,10 @@ message BufferAssignmentProto { // Grouping message that contains all of the information above. message HloProto { + reserved 2; + reserved "hlo_ordering"; + HloModuleProto hlo_module = 1; - HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 54abe3345d25a8cc1fdd66bd6ee75157fe9b7f77..0cd0ab36fcf832af9a71ab5837c94f9b39bc4bf3 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -885,18 +885,20 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { // For a sequential order, if there is interference iff the negate is after // the while. 
- SequentialHloOrdering::HloModuleSequence sequence; - sequence[body] = {body_param, body_root}; - sequence[condition] = {cond_param, cond_root}; + HloSchedule schedule(module_); + schedule.set_sequence(body, {body_param, body_root}); + schedule.set_sequence(condition, {cond_param, cond_root}); { - sequence[entry] = {init, xla_while, negate, entry_root}; - SequentialHloOrdering ordering(module_, sequence); + schedule.set_sequence(entry, {init, xla_while, negate, entry_root}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } { - sequence[entry] = {init, negate, xla_while, entry_root}; - SequentialHloOrdering ordering(module_, sequence); + schedule.set_sequence(entry, {init, negate, xla_while, entry_root}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index fe7f2be888d2037e4f6d3879bcc716de4eee07f9..8c6903d76628f87b01de044f1e49de367bf38110 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -464,6 +464,14 @@ std::vector HloComputation::MakeEmbeddedComputationsList() } string HloComputation::ToString(const HloPrintOptions& options) const { + return ToString(options, MakeInstructionPostOrder()); +} + +string HloComputation::ToString( + const HloPrintOptions& options, + absl::Span instruction_order) const { + CHECK_EQ(instruction_order.size(), instruction_count()); + std::ostringstream s; for (int i = 0; i < options.indent_amount(); i++) { s << " "; @@ -486,7 +494,9 @@ string HloComputation::ToString(const HloPrintOptions& options) const { new_options.set_indent_amount(options.indent_amount() + 1) .set_is_in_nested_computation(true); CanonicalNameMap name_map; - for (const HloInstruction* instruction : 
MakeInstructionPostOrder()) { + for (const HloInstruction* instruction : instruction_order) { + CHECK_EQ(this, instruction->parent()); + for (int i = 0; i < new_options.indent_amount(); i++) { s << " "; } @@ -552,9 +562,11 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); - return absl::WrapUnique(new HloComputation(proto.name(), parameter_count, - &instructions, root, - /*fusion_instruction=*/nullptr)); + auto computation = absl::WrapUnique( + new HloComputation(proto.name(), parameter_count, &instructions, root, + /*fusion_instruction=*/nullptr)); + computation->unique_id_ = proto.id(); + return std::move(computation); } void HloComputation::FuseInstructionsInto( diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index fe2d3bbbe53bdcb7b2ea8a35f35e50fb3e8823b4..91c5234a6fde6698c5d600d667e3370d44134a50 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -170,6 +170,11 @@ class HloComputation { string ToString() const { return ToString(HloPrintOptions()); } string ToString(const HloPrintOptions& options) const; + // Overload which accepts an order to emit the instructions in. + string ToString( + const HloPrintOptions& options, + absl::Span instruction_order) const; + // Returns a serialized representation of this computation. 
HloComputationProto ToProto() const; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index f7ed1b0316b213a0f34b1d690229f0173dbd5250..2aaaef1d36d58bcce18db4aa37ff05ea352e484b 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -601,8 +601,11 @@ TEST_F(HloComputationTest, Stringification) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -633,8 +636,11 @@ TEST_F(HloComputationTest, StringificationIndent) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -666,8 +672,11 @@ TEST_F(HloComputationTest, StringificationCanonical) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, 
precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 8a45939c61755876555bc35c49d7d6c781f8b4fe..f837816cea78d78bb3d605dd91e81cac39036268 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -76,10 +76,10 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } - std::unique_ptr result = evaluator->TryEvaluate(instruction); + Literal result; // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations. - if (result == nullptr) { + if (!evaluator->TryEvaluate(instruction, &result)) { VLOG(2) << "Constant folding failed for instruction: " << instruction->ToString(); continue; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 07cd1efc1208309770478885532e0284bdb1fbcc..3e0def5d26a0033d954a776c1c32d6c35acfb505 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -28,7 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" @@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using HloConstantFoldingTest = HloTestBase; +using HloConstantFoldingTest = HloVerifiedTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); @@ -52,7 +52,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -73,7 +73,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -94,7 +94,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { EXPECT_THAT(computation->root_instruction(), op::Convert(input)); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); @@ -134,7 +134,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; 
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -161,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -175,7 +175,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->Literal::CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); @@ -186,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { auto computation = module->AddEntryComputation(builder.Build()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module)); EXPECT_TRUE(result); HloInstruction* root = computation->root_instruction(); @@ -198,7 +198,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { root->literal().EachCell( [&](absl::Span indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - matched = matched && (value == literal_clone->Get(rindexes)); + matched = matched && (value == literal_clone.Get(rindexes)); }); EXPECT_TRUE(matched); } @@ -219,28 +219,27 @@ const char* const kConstantFoldReduce = R"( })"; TEST_F(HloConstantFoldingTest, ConstantFoldReduce) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - 
ParseHloString(kConstantFoldReduce)); + ParseAndVerifyModule(kConstantFoldReduce); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); EXPECT_TRUE(result); - EXPECT_EQ(6, module->entry_computation() + EXPECT_EQ(6, module() + .entry_computation() ->root_instruction() ->literal() .GetFirstElement()); } TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(kConstantFoldReduce)); - HloInstruction* add = module->computations().begin()->root_instruction(); + ParseAndVerifyModule(kConstantFoldReduce); + HloInstruction* add = module().computations().begin()->root_instruction(); LayoutUtil::ClearLayout(add->mutable_shape()); HloConstantFolding const_folder; - TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module())); EXPECT_FALSE(result); - EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 939b5114c3f8f93ad2d768e77db302ae83e44d17..a502fff9a0f1e40065746f2193bf76b1adefdb31 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -227,6 +227,14 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) { + // Domain does not have any computation or data transfer. 
+ current_should_compute_bottleneck_time_ = false; + current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; + return Status::OK(); +} + Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); @@ -507,8 +515,9 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { valid_position_counts.push_back(valid_position_count); } - const int64 fma_count = - input_feature * output_feature * batch * Product(valid_position_counts); + const int64 fma_count = (input_feature / convolution->feature_group_count()) * + output_feature * batch * + Product(valid_position_counts); current_properties_[kFlopsKey] = fma_count * kFmaFlops; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 9bb3f12ee2c7867d71de61c5077f129fdf59ef75..46b4bbeef222e6de581360fc01b293e812f1dedd 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -67,6 +67,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleRecvDone(const HloInstruction* recv_done) override; Status HandleConvert(const HloInstruction* convert) override; Status HandleCopy(const HloInstruction* copy) override; + Status HandleDomain(const HloInstruction* domain) override; Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 2c854eea18642eb7cb081b4fdfe3bc83627e41ae..d76ce9ecbca67ae3bc3db4ee2452f30ccec5b88b 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ 
-203,6 +203,35 @@ TEST_F(HloCostAnalysisTest, Convolution) { sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18)); } +TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) { + XlaBuilder builder("convolution"); + auto input = Parameter( + &builder, 0, + ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10, + /*x_dim=*/20}), + "input"); + auto kernel = Parameter( + &builder, 1, + ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3, + /*x_dim=*/3}), + "kernel"); + Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120); + + // Run HLO cost analysis. + auto hlo_module = BuildHloGraph(&builder); + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + // Output shape is [1x120x8x18] and each output element requires (3x3) + // FMAs and one FMA is 2 flops. + EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3); + + // Bytes accessed is sum of inputs and output. + EXPECT_EQ(analysis.bytes_accessed(), + sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18)); +} + TEST_F(HloCostAnalysisTest, Reduce) { XlaBuilder builder("reduce"); auto input = @@ -415,7 +444,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { - XlaBuilder builder("matmul"); + XlaBuilder builder("tuple"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y"); Tuple(&builder, {x, y}); @@ -430,6 +459,30 @@ TEST_F(HloCostAnalysisTest, TupleCost) { EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2); } +using DomainCostAnalysis = HloTestBase; +TEST_F(DomainCostAnalysis, DomainCost) { + HloCostAnalysis analysis(ShapeSize); + + HloComputation::Builder builder("domain"); + auto x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {123}), "x")); + auto y = builder.AddInstruction( + 
HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y")); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y})); + auto domain = builder.AddInstruction( + HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain); + ASSERT_IS_OK(domain->Accept(&analysis)); + + EXPECT_EQ(analysis.flop_count(*domain), 0); + EXPECT_EQ(analysis.transcendental_count(*domain), 0); + EXPECT_EQ(analysis.bytes_accessed(*domain), 0); +} + TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { XlaBuilder builder("BaseDilatedConvolution"); auto input = Parameter( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 19ffb465c04ccc720ba6a8a14b187691a62b5c24..b76c50bb5b99cf4c9e6d4e04c240e8159acfc338 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -61,15 +61,18 @@ StatusOr MakeSliceHlo(HloInstruction* operand, } StatusOr MakeConvolveHlo( - HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers) { + HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); - TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), - window, dimension_numbers)); + TF_ASSIGN_OR_RETURN(Shape convolve_shape, + ShapeInference::InferConvolveShape( + lhs->shape(), rhs->shape(), feature_group_count, + window, dimension_numbers)); return computation->AddInstruction(HloInstruction::CreateConvolve( - convolve_shape, lhs, rhs, 
window, dimension_numbers)); + convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers, + precision_config)); } StatusOr MakeTransposeHlo(HloInstruction* operand, @@ -165,14 +168,15 @@ StatusOr MakeConcatHlo( } StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers) { + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN( Shape dot_shape, ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers)); - return computation->AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); + return computation->AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dim_numbers, precision_config)); } StatusOr MakeMapHlo(absl::Span operands, @@ -317,18 +321,17 @@ StatusOr PadVectorWithZeros(HloInstruction* operand, padding_config_dim.set_edge_padding_high(zeros_to_append); *padding_config.add_dimensions() = padding_config_dim; - HloInstruction* zero = computation->AddInstruction( - HloInstruction::CreateConstant(absl::make_unique( - LiteralUtil::Zero(operand->shape().element_type())))); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(operand->shape().element_type()))); return MakePadHlo(operand, zero, padding_config); } StatusOr BroadcastZeros( HloComputation* computation, PrimitiveType element_type, absl::Span broadcast_dimensions) { - HloInstruction* zero = - computation->AddInstruction(HloInstruction::CreateConstant( - absl::make_unique(LiteralUtil::Zero(element_type)))); + HloInstruction* zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h 
b/tensorflow/compiler/xla/service/hlo_creation_utils.h index a1c4b374d1121bbf94f5940b52859682808119c4..b22058abb4dcbf17631f28e4eacf6c7f1da781d2 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -48,8 +48,9 @@ StatusOr MakeSliceHlo(HloInstruction* operand, // Creates a convolution HLO instruction and adds it to the computation // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). StatusOr MakeConvolveHlo( - HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); // Creates a transpose HLO instruction and adds it to the computation containing // `operand`. @@ -98,7 +99,8 @@ StatusOr MakeConcatHlo( // Creates a Dot HLO instruction and adds it to the computation containing `lhs` // and `rhs` (both must be in the same computation). StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers); + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config); // Creates a Map HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. 
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index eb6affadc800d9d5cf7b143386b46f3e8c608e63..e07a196d1154dc0ea45ccd2f15b0b9b56f7c41f8 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -57,10 +57,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { entry_computation->set_root_instruction(first_1_dims_collapsed); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({3, 4})); + CHECK_EQ(result_literal, LiteralUtil::CreateR1({3, 4})); } TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { @@ -78,13 +78,13 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( + Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR2( + CHECK_EQ(result_literal, + LiteralUtil::CreateR2( {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); } @@ -103,10 +103,10 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9, 10}})); + Literal result_literal, + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9, 10}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { @@ -124,10 +124,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { HloEvaluator evaluator; 
TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR1({9, 10})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR3({{{9, 10}}})); + Literal result_literal, + evaluator.Evaluate(*module, + {LiteralUtil::CreateR1({9, 10})})); + CHECK_EQ(result_literal, LiteralUtil::CreateR3({{{9, 10}}})); } TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { @@ -144,10 +144,10 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR0(9)})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{9}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(9)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{9}})); } TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { @@ -165,11 +165,11 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { HloEvaluator evaluator; TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result_literal, - evaluator.Evaluate>( + Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6})})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); } TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { @@ -187,10 +187,10 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { entry_computation->set_root_instruction(zero_padded_param); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR1({3, 4})})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); + CHECK_EQ(result_literal, 
LiteralUtil::CreateR1({0, 0, 0, 3, 4, 0})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { @@ -208,10 +208,10 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( - *module, {LiteralUtil::CreateR0(0)})); - CHECK_EQ(*result_literal, *LiteralUtil::CreateR2({{0, 0}, {0, 0}})); + TF_ASSERT_OK_AND_ASSIGN( + Literal result_literal, + evaluator.Evaluate(*module, {LiteralUtil::CreateR0(0)})); + CHECK_EQ(result_literal, LiteralUtil::CreateR2({{0, 0}, {0, 0}})); } TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { @@ -229,11 +229,11 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { entry_computation->set_root_instruction(zeros); HloEvaluator evaluator; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, - evaluator.Evaluate>( + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + evaluator.Evaluate( *module, {LiteralUtil::CreateR0(0.0f)})); - CHECK_EQ(*result_literal, - *LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); + CHECK_EQ(result_literal, + LiteralUtil::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index cb367adf5ef29111838dd6ee1b770394eef1301c..b59c9ba3ed7990eb2a35abc83f87b25a1b1e7c60 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -34,7 +35,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/hash/hash.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 406d712ec6783a310aabc6600b8b70e1a1ae30a9..9b18b0284f63c25934c1b7118dc8973caa62cadc 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" @@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class HloCseTest : public HloTestBase { +class HloCseTest : public HloVerifiedTestBase { protected: HloCseTest() {} }; @@ -65,15 +65,15 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_EQ(3, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); HloInstruction* constant = *computation->instructions().begin(); EXPECT_EQ(42.0f, constant->literal().Get({})); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR0(84.0); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } 
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { @@ -96,16 +96,16 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); auto first_operand = add->operand(0); EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2)); EXPECT_THAT(add, op::Add(first_operand, first_operand)); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { @@ -128,14 +128,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); EXPECT_THAT(add, op::Add(constant1, constant2)); - auto result = ExecuteAndTransfer(std::move(module), {}); + auto result = ExecuteAndTransfer(module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, ConstantsSameValueDifferentType) { @@ -177,7 +177,7 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); // CSE 
will remove both the second float(42.0f) and the corresponding // convert/cast. @@ -209,7 +209,7 @@ TEST_F(HloCseTest, NonscalarConstants) { op::Tuple(common_constant1, common_constant2, uncommon_constant)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -240,7 +240,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -250,7 +250,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test two identical while loops with same inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -278,21 +278,20 @@ f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); } // Test two while loops with same conditions, same inputs, but different // bodies TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies %body (param: (f32[], f32[])) -> 
(f32[], f32[]) { @@ -329,20 +328,19 @@ index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[] condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body2 } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } // Test two identical while loops with different inputs TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -373,21 +371,20 @@ f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[] condition=%condition.1, body=%body } - )") - .ValueOrDie(); + )"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(8, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(8, computation->instruction_count()); } // Test two while loops with identical bodies and same inputs, but different // conditions TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions %body (param: (f32[], f32[])) -> (f32[], f32[]) { @@ -414,14 +411,13 @@ f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2) %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1), condition=%condition.1, body=%body - 
})") - .ValueOrDie(); + })"); - auto computation = module->entry_computation(); + auto computation = module().entry_computation(); EXPECT_EQ(5, computation->instruction_count()); HloCSE cse(true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(&module()).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); } @@ -450,7 +446,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/true); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); @@ -481,7 +477,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(3, computation->instruction_count()); auto first_operand = tuple->operand(0); @@ -516,7 +512,7 @@ TEST_F(HloCseTest, FusionInternalCSE) { EXPECT_EQ(5, fused_computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(4, fused_computation->instruction_count()); auto root = fused_computation->root_instruction(); @@ -565,7 +561,7 @@ TEST_F(HloCseTest, IdenticalExpressions) { EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + EXPECT_TRUE(cse.Run(module).ValueOrDie()); EXPECT_EQ(5, computation->instruction_count()); auto operand = tuple->operand(0); @@ -599,7 +595,7 @@ TEST_F(HloCseTest, DoNotCombineRng) { uint32 count_before = computation->instruction_count(); HloCSE cse(/*is_layout_sensitive=*/false); - 
EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); uint32 count_after = computation->instruction_count(); EXPECT_EQ(count_before, count_after); @@ -653,7 +649,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { VLOG(3) << "before: " << module->ToString(); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); VLOG(3) << "after: " << module->ToString(); @@ -663,7 +659,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { } TEST_F(HloCseTest, CompareComputations) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule m add_computation { @@ -684,12 +680,11 @@ TEST_F(HloCseTest, CompareComputations) { r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 ROOT f2 = (f32[],f32[]) tuple(r1, r2) - })") - .ValueOrDie(); + })"); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + HloInstruction* root = module().entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0), root->operand(1)); } @@ -708,13 +703,13 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { EXPECT_EQ(2, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(cse.Run(module).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); } TEST_F(HloCseTest, Domain) { - auto module = ParseHloString(R"( + ParseAndVerifyModule(R"( HloModule module ENTRY %entry { %param = f32[] parameter(0), sharding={maximal device=0} @@ -735,13 +730,11 @@ ENTRY %entry { domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}} %add = f32[] add(%domain.3, %domain.4) ROOT %sub = f32[] 
subtract(%add, %domain.5) -})") - .ValueOrDie(); +})"); HloCSE cse(/*is_layout_sensitive=*/false); - EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - LOG(INFO) << "AAAAA " << module->ToString(); - const HloInstruction* sub = module->entry_computation()->root_instruction(); + EXPECT_TRUE(cse.Run(&module()).ValueOrDie()); + const HloInstruction* sub = module().entry_computation()->root_instruction(); const HloInstruction* add = sub->operand(0); EXPECT_EQ(add->operand(0), add->operand(1)); EXPECT_NE(add->operand(0), sub->operand(1)); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index d1a96c10f88e3c05e21a6db4eccb46683cd64c4a..510d6360a1cf94ef06d2ed919a57c7a825886834 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -1261,9 +1262,10 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param0, negate, param1, exp, add}}); - SequentialHloOrdering ordering(module_.get(), sequence); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param0, negate, param1, exp, add}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. 
@@ -1339,14 +1341,16 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { bool ssa_form = GetParam(); RunAnalysis(ssa_form); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param, xla_while}}); - sequence.insert({condition, {cond_param, cond_constant}}); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param, xla_while}); + schedule.set_sequence(condition, {cond_param, cond_constant}); // Construct the order such that 'constant' and its use 'exp' are before // body_param. - sequence.insert({body, {constant, exp, body_param, add}}); + schedule.set_sequence( + body, {constant, exp, body_param, add, dead_constant, dead_negate}); + TF_ASSERT_OK(schedule.Verify()); - SequentialHloOrdering ordering(module_.get(), sequence); + SequentialHloOrdering ordering(schedule); // 'add' is live out of the body and will interfere with an later instructions // such as 'dead_constant' and 'dead_negate'. @@ -1476,11 +1480,10 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - SequentialHloOrdering::HloModuleSequence sequence; - std::vector order = {param, negate, exp, add}; - sequence.emplace(entry, order); - - SequentialHloOrdering ordering(module_.get(), sequence); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param, negate, exp, add}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); @@ -2334,8 +2337,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); auto dot = builder.AddInstruction( 
- HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 8b2846e0c277b3e7cffd578d988d0a09c13833ed..113fd18eae70f0a581e2ab3e44544c47fcab3361 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -51,6 +51,10 @@ int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { return FindOrDefault(instruction_to_domain_, instruction, -1); } +int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const { + return FindOrDie(domain_metadata_id_, instruction); +} + Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); // We only check operands, so we are sure to not process the empty domain from @@ -93,6 +97,43 @@ Status HloDomainMap::Populate(HloComputation* computation) { CreateDomain(instruction, instructions_post_order)); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } + TF_RETURN_IF_ERROR(PopulateDomainMetadataMap()); + return Status::OK(); +} + +Status HloDomainMap::PopulateDomainMetadataMap() { + auto hash = [](const DomainMetadata* m) { return m->Hash(); }; + auto equal = [](const DomainMetadata* a, const DomainMetadata* b) { + return a->Matches(*b); + }; + tensorflow::gtl::FlatMap + domain_metadata(1024, hash, equal); + + for (auto& domain : instruction_domains_) { + int64 domain_metadata_id = -1; + if (!domain->enter_domains.empty()) { + const HloInstruction* domain_instruction = *domain->enter_domains.begin(); + domain_metadata_id = + domain_metadata + .insert({&domain_instruction->user_side_metadata(), + domain_metadata.size() + 1}) + .first->second; + } else if (!domain->exit_domains.empty()) { + 
const HloInstruction* domain_instruction = *domain->exit_domains.begin(); + domain_metadata_id = + domain_metadata + .insert({&domain_instruction->operand_side_metadata(), + domain_metadata.size() + 1}) + .first->second; + } else { + domain_metadata_id = 0; + } + TF_RET_CHECK(domain_metadata_id >= 0); + for (HloInstruction* instruction : domain->instructions) { + domain_metadata_id_[instruction] = domain_metadata_id; + } + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 633109249a91eec3d7b4cbe5b423b73f980217c9..56b557d7cea424f63cd4891661ae446133ee5a37 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -69,6 +69,11 @@ class HloDomainMap { // instruction is not found within any domain. int64 GetDomainId(HloInstruction* instruction) const; + // Returns the unique id of the domain metadata for the domain the given + // instruction belongs to. The given instruction must not be a kDomain + // instruction since each domain instruction is associated with 2 domains. + int64 GetDomainMetadataId(HloInstruction* instruction) const; + private: // Map used for representing instruction ordering, i.e. // order_map[a] < order_map[b] means a must be ordered before b. @@ -109,9 +114,14 @@ class HloDomainMap { const tensorflow::gtl::FlatSet& instruction_set, const InstructionOrderMap& instructions_order); + // Populates domain_metadata_id_ that maps each HloInstruction to the unique + // ID of its associated domain metatadata. 
+ Status PopulateDomainMetadataMap(); + string domain_kind_; std::vector> instruction_domains_; tensorflow::gtl::FlatMap instruction_to_domain_; + tensorflow::gtl::FlatMap domain_metadata_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index 6c142ee47421049e8a25dfb80a6297e02fe782f1..302807f816e4ab626af419023e7740fd6bde795f 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -72,6 +72,9 @@ class DomainMetadata { // two matches. virtual bool Matches(const DomainMetadata& other) const = 0; + // Returns the hash value of the metadata. + virtual size_t Hash() const = 0; + // Returns a string representation of the metadata. virtual string ToString() const = 0; }; diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 974ab94467dfb63325698b4590dac1abd1ed9f89..43e74d2f6f07bd685ad8683401138a4f06cd2ad2 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -99,6 +99,8 @@ class OpNameMetadata : public DomainMetadata { static absl::string_view KindName() { return "opname"; } + size_t Hash() const override { return std::hash()(opname_); } + private: string opname_; }; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 441dcad00047311d682c0623964ee63aab341904..06b6d5b5592c5849dd247fc19fc52ab0a2113fe8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -53,11 +53,9 @@ namespace xla { namespace { - template -StatusOr> Compare(const Shape& shape, HloOpcode opcode, - LiteralSlice lhs_literal, - LiteralSlice rhs_literal) { +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, LiteralSlice rhs_literal) { 
std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -95,9 +93,9 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); })); @@ -106,9 +104,9 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, } template <> -StatusOr> Compare( - const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, - LiteralSlice rhs_literal) { +StatusOr Compare(const Shape& shape, HloOpcode opcode, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -126,9 +124,9 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); })); @@ -194,7 +192,7 @@ HloEvaluator::HloEvaluator(int64 max_loop_iterations) } template -StatusOr> HloEvaluator::Evaluate( +StatusOr HloEvaluator::Evaluate( const HloModule& module, absl::Span arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); @@ -207,11 +205,21 @@ StatusOr> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) - .CloneToUnique(); + .Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + const HloModule& module, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal_ptr : arg_literals) { + arg_literal_ptrs.push_back(&literal_ptr); + } + return Evaluate(module, arg_literal_ptrs); } template -StatusOr> HloEvaluator::Evaluate( 
+StatusOr HloEvaluator::Evaluate( const HloComputation& computation, absl::Span arg_literals) { CHECK(computation.parent() != nullptr); @@ -225,11 +233,21 @@ StatusOr> HloEvaluator::Evaluate( } TF_RETURN_IF_ERROR(computation.Accept(this)); - return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique(); + return GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); +} + +template <> +StatusOr HloEvaluator::Evaluate( + const HloComputation& computation, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal_ptr : arg_literals) { + arg_literal_ptrs.push_back(&literal_ptr); + } + return Evaluate(computation, arg_literal_ptrs); } template -StatusOr> HloEvaluator::Evaluate( +StatusOr HloEvaluator::Evaluate( HloInstruction* instruction, absl::Span arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); @@ -248,18 +266,27 @@ StatusOr> HloEvaluator::Evaluate( << input_literal->ToString(); TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); - evaluated_[operand] = input_literal->CloneToUnique(); + evaluated_[operand] = input_literal->Clone(); } } TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).CloneToUnique(); + return GetEvaluatedLiteralFor(instruction).Clone(); } -StatusOr> HloEvaluator::Evaluate( - HloInstruction* instruction) { +template <> +StatusOr HloEvaluator::Evaluate( + HloInstruction* instruction, absl::Span arg_literals) { + std::vector arg_literal_ptrs; + for (const auto& literal : arg_literals) { + arg_literal_ptrs.push_back(&literal); + } + return Evaluate(instruction, arg_literal_ptrs); +} + +StatusOr HloEvaluator::Evaluate(HloInstruction* instruction) { if (instruction->opcode() == HloOpcode::kParameter) { return tensorflow::errors::FailedPrecondition( "Cannot evaluate a parameter."); @@ -275,21 +302,22 @@ 
StatusOr> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return GetEvaluatedLiteralFor(instruction).CloneToUnique(); + return GetEvaluatedLiteralFor(instruction).Clone(); } -std::unique_ptr HloEvaluator::TryEvaluate( - HloInstruction* instruction) { +bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) { + CHECK(result != nullptr); auto result_or = Evaluate(instruction); if (!result_or.ok()) { VLOG(1) << "TryEvaluate failed:" << result_or.status(); - return nullptr; + return false; } - return result_or.ConsumeValueOrDie(); + *result = result_or.ConsumeValueOrDie(); + return true; } -StatusOr> HloEvaluator::EvaluateWithSubstitutions( +StatusOr HloEvaluator::EvaluateWithSubstitutions( const HloInstruction* instruction, const std::unordered_map& substitutions) { @@ -300,7 +328,7 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( owned_operands.push_back(operand->Clone()); } else { owned_operands.push_back( - HloInstruction::CreateConstant(it->second->CloneToUnique())); + HloInstruction::CreateConstant(it->second->Clone())); } } @@ -317,12 +345,12 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( return result; } -StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( +StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = - HloInstruction::CreateConstant(lhs.CloneToUnique()); + HloInstruction::CreateConstant(lhs.Clone()); std::unique_ptr rhs_instr = - HloInstruction::CreateConstant(rhs.CloneToUnique()); + HloInstruction::CreateConstant(rhs.Clone()); std::unique_ptr cloned_instruction = HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(), @@ -332,10 +360,10 @@ StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( return result; } -StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( +StatusOr 
HloEvaluator::EvaluateElementwiseUnaryOp( HloOpcode opcode, const Literal& operand) { std::unique_ptr operand_instr = - HloInstruction::CreateConstant(operand.CloneToUnique()); + HloInstruction::CreateConstant(operand.Clone()); std::unique_ptr cloned_instruction = HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); @@ -344,13 +372,14 @@ StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( return result; } -StatusOr> HloEvaluator::EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, const Literal& lhs, +StatusOr HloEvaluator::EvaluateDotOp( + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = - HloInstruction::CreateConstant(lhs.CloneToUnique()); + HloInstruction::CreateConstant(lhs.Clone()); std::unique_ptr rhs_instr = - HloInstruction::CreateConstant(rhs.CloneToUnique()); + HloInstruction::CreateConstant(rhs.Clone()); TF_ASSIGN_OR_RETURN( Shape dot_shape, @@ -358,7 +387,7 @@ StatusOr> HloEvaluator::EvaluateDotOp( std::unique_ptr cloned_instruction = HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(), - dim_numbers); + dim_numbers, precision_config); return Evaluate(cloned_instruction.get()); } @@ -371,7 +400,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { << ", but input literal shape is: " << ShapeUtil::HumanString(input_literal->shape()); - evaluated_[parameter] = input_literal->CloneToUnique(); + evaluated_[parameter] = input_literal->Clone(); return Status::OK(); } @@ -421,7 +450,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (auto operand : operands) { const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( GetEvaluatedLiteralFor(operand), source_indices, dest_indices, AsInt64Slice(operand_shape.dimensions()))); dest_indices[concat_dim] += @@ -824,7 +853,7 @@ class 
OutputOffsetIndexToInputIndex { // there is one) to `reshaped_start_indices`. static StatusOr> ReshapedGatherIndices( int64 index_vector_dim, const Literal& start_indices, - std::unique_ptr* reshaped_start_indices) { + Literal* reshaped_start_indices) { if (start_indices.shape().dimensions_size() != index_vector_dim) { return std::cref(start_indices); } @@ -834,16 +863,16 @@ static StatusOr> ReshapedGatherIndices( new_shape.push_back(1); TF_ASSIGN_OR_RETURN(*reshaped_start_indices, start_indices.Reshape(new_shape)); - return std::cref(**reshaped_start_indices); + return std::cref(*reshaped_start_indices); } Status HloEvaluator::HandleGather(HloInstruction* gather) { - std::unique_ptr result = Literal::CreateFromShape(gather->shape()); + Literal result = Literal::CreateFromShape(gather->shape()); const Shape& shape = gather->shape(); const GatherDimensionNumbers& dim_numbers = gather->gather_dimension_numbers(); const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0)); - std::unique_ptr reshaped_start_indices; + Literal reshaped_start_indices; TF_ASSIGN_OR_RETURN( const Literal& start_indices, ReshapedGatherIndices(dim_numbers.index_vector_dim(), @@ -908,7 +937,7 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { DCHECK_LT(input_index[i], operand_shape.dimensions(i)); } TF_RETURN_IF_ERROR( - result->CopyElementFrom(operand, input_index, output_index)); + result.CopyElementFrom(operand, input_index, output_index)); return true; }; @@ -940,8 +969,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { // Checks that operand's dimensions are the same as the broadcast's // dimensions along the dimensions to be broadcasted. 
for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand.shape().dimensions(i)); + auto operand_dim_size = operand.shape().dimensions(i); + auto broadcast_dim_size = + broadcast->shape().dimensions(broadcast->dimensions(i)); + TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat( + "Operand dimension %d is broadcast to output dimension %d, but the " + "sizes of these two dims do not match (%d vs %d): %s", + i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size, + broadcast->ToString()); } TF_ASSIGN_OR_RETURN( @@ -971,18 +1006,16 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); - evaluated_[get_tuple_element] = absl::make_unique( - ShapeUtil::GetTupleElementShape(operand->shape(), index)); - return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, - /*dest_shape_index=*/{}, - /*src_shape_index=*/{index}); + evaluated_[get_tuple_element] = + Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index)); + return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal, + /*dest_shape_index=*/{}, + /*src_shape_index=*/{index}); } Status HloEvaluator::HandleCopy(HloInstruction* copy) { TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape())); - - auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique(); - evaluated_[copy] = std::move(result); + evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone(); return Status::OK(); } @@ -998,7 +1031,7 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } HloEvaluator embedded_evaluator; - std::unique_ptr result = + Literal result = embedded_evaluator.Evaluate(*computation, arg_literals) .ConsumeValueOrDie(); @@ -1030,7 +1063,7 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { } HloEvaluator embedded_evaluator; - 
std::unique_ptr result = + Literal result = embedded_evaluator .Evaluate(*readded_computation, arg_literals) .ConsumeValueOrDie(); @@ -1050,7 +1083,7 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* false_computation = conditional->false_computation(); HloEvaluator embedded_evaluator; - std::unique_ptr result; + Literal result; if (pred.Get({})) { result = embedded_evaluator .Evaluate(*true_computation, @@ -1075,9 +1108,9 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { // If predicate is of scalar type, no element-wise selection would be needed. if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get({})) { - evaluated_[select] = on_true.CloneToUnique(); + evaluated_[select] = on_true.Clone(); } else { - evaluated_[select] = on_false.CloneToUnique(); + evaluated_[select] = on_false.Clone(); } return Status::OK(); } @@ -1091,9 +1124,9 @@ Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) { const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2)); if (pred.Get({})) { - evaluated_[tuple_select] = on_true.CloneToUnique(); + evaluated_[tuple_select] = on_true.Clone(); } else { - evaluated_[tuple_select] = on_false.CloneToUnique(); + evaluated_[tuple_select] = on_false.Clone(); } return Status::OK(); } @@ -1102,7 +1135,7 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { HloComputation* cond_comp = while_hlo->while_condition(); HloComputation* body_comp = while_hlo->while_body(); // Initialize the loop carried valued with the input to the While instruction. 
- auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique(); + auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone(); bool keep_going = true; int64 iteration_count = 0; HloEvaluator cond_evaluator(max_loop_iterations_); @@ -1112,13 +1145,13 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", while_hlo->name(), max_loop_iterations_); } - TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate( - *cond_comp, {lcv.get()})); - keep_going = cond_val->GetFirstElement(); + TF_ASSIGN_OR_RETURN(auto cond_val, + cond_evaluator.Evaluate(*cond_comp, {&lcv})); + keep_going = cond_val.GetFirstElement(); if (keep_going) { TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate( - *body_comp, {lcv.get()})); - VLOG(3) << "Loop iteration result: " << body_val->ToString(); + *body_comp, {&lcv})); + VLOG(3) << "Loop iteration result: " << body_val.ToString(); lcv = std::move(body_val); cond_evaluator.ResetVisitStates(); loop_body_evaluator.ResetVisitStates(); @@ -1133,9 +1166,9 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { // hoops to make this work. 
namespace { template -StatusOr> EvaluateSortInternal( - HloInstruction* sort, const Literal& keys_literal, - const Literal& values_literal) { +StatusOr EvaluateSortInternal(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { auto rank = ShapeUtil::Rank(keys_literal.shape()); TF_RET_CHECK( ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) @@ -1173,57 +1206,55 @@ StatusOr> EvaluateSortInternal( result_keys.push_back(key_value.first); result_values.push_back(key_value.second); } - auto result_keys_literal = absl::make_unique(keys_literal.shape()); - result_keys_literal->PopulateR1(absl::Span(result_keys)); - auto result_values_literal = - absl::make_unique(values_literal.shape()); - result_values_literal->PopulateR1( + Literal result_keys_literal(keys_literal.shape()); + result_keys_literal.PopulateR1(absl::Span(result_keys)); + Literal result_values_literal(values_literal.shape()); + result_values_literal.PopulateR1( absl::Span(result_values)); return std::make_pair(std::move(result_keys_literal), std::move(result_values_literal)); }; - std::unique_ptr result_tuple; + Literal result_tuple; if (rank == 1) { auto result_pair = sort_r1(keys_literal, values_literal); - result_tuple = LiteralUtil::MakeTuple( - {result_pair.first.get(), result_pair.second.get()}); + result_tuple = + LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second}); } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. 
- auto keys_result_literal = absl::make_unique(keys_literal.shape()); - auto values_result_literal = - absl::make_unique(values_literal.shape()); + Literal keys_result_literal(keys_literal.shape()); + Literal values_result_literal(values_literal.shape()); int64 r1_length = keys_literal.shape().dimensions(1); for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto keys_r1_slice, keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); + .Reshape({r1_length})); TF_ASSIGN_OR_RETURN(auto values_r1_slice, values_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice); + .Reshape({r1_length})); + auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice); TF_ASSIGN_OR_RETURN(auto sorted_keys, - r1_result_pair.first->Reshape({1, r1_length})); + r1_result_pair.first.Reshape({1, r1_length})); TF_ASSIGN_OR_RETURN(auto sorted_values, - r1_result_pair.second->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom( - *sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); - TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom( - *sorted_values, {0, 0}, {row, 0}, {1, r1_length})); + r1_result_pair.second.Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( + sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); + TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( + sorted_values, {0, 0}, {row, 0}, {1, r1_length})); } - result_tuple = LiteralUtil::MakeTuple( - {keys_result_literal.get(), values_result_literal.get()}); + result_tuple = + LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); } - VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString(); + VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); return std::move(result_tuple); } template -StatusOr> EvaluateSortCurried( - HloInstruction* sort, const Literal& keys_literal, - const 
Literal& values_literal) { +StatusOr EvaluateSortCurried(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { switch (sort->operand(1)->shape().element_type()) { case F32: return EvaluateSortInternal(sort, keys_literal, @@ -1242,9 +1273,9 @@ StatusOr> EvaluateSortCurried( } } -StatusOr> EvaluateSort(HloInstruction* sort, - const Literal& keys_literal, - const Literal& values_literal) { +StatusOr EvaluateSort(HloInstruction* sort, + const Literal& keys_literal, + const Literal& values_literal) { switch (sort->operand(0)->shape().element_type()) { case F32: return EvaluateSortCurried(sort, keys_literal, values_literal); @@ -1308,33 +1339,25 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) { Status HloEvaluator::Postprocess(HloInstruction* hlo) { VLOG(2) << "Finished visiting " << hlo->ToString() << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString(); + // Out of convenience the literal may have been produced with a different + // layout. Relayout as indicated by the HLO instruction. + if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(), + hlo->shape())) { + evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape()); + } return Status::OK(); } // Explicit instantiation of templatized Evaluate* methods. 
// -template StatusOr> -HloEvaluator::Evaluate( +template StatusOr HloEvaluator::Evaluate( const HloModule& module, absl::Span arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( - const HloModule& module, - absl::Span> arg_literals); - -template StatusOr> HloEvaluator::Evaluate< - const Literal*>(const HloComputation& computation, - absl::Span arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( + +template StatusOr HloEvaluator::Evaluate( const HloComputation& computation, - absl::Span> arg_literals); + absl::Span arg_literals); -template StatusOr> -HloEvaluator::Evaluate( +template StatusOr HloEvaluator::Evaluate( HloInstruction* instruction, absl::Span arg_literals); -template StatusOr> -HloEvaluator::Evaluate>( - HloInstruction* instruction, - absl::Span> arg_literals); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index c2d49e56ac487ee8a5cb3d26aee497ade63aa844..21e676d671af08d1626ca6f157db63bf8d23ae0b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -47,11 +47,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Precondition: The indices of arg_literals correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - const HloModule& module, absl::Span arg_literals); + StatusOr Evaluate(const HloModule& module, + absl::Span arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. @@ -69,12 +69,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. 
The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - const HloComputation& computation, - absl::Span arg_literals); + StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals); // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. @@ -82,41 +81,43 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // 1. argument literals correspond to the input instruction's parameters in // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. - // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // `LiteralPtr` accepts either Literal or const Literal* // type. template - StatusOr> Evaluate( - HloInstruction* instruction, absl::Span arg_literals); + StatusOr Evaluate(HloInstruction* instruction, + absl::Span arg_literals); // Evaluates a single HLO instruction with constant operands. // Returns the evaluated result as literal if successful. // Precondition: // 1. all operands of the input instruction are constants. // 2. the instruction is not a Parameter operation. - StatusOr> Evaluate(HloInstruction* instruction); + StatusOr Evaluate(HloInstruction* instruction); - // Same as Evaluate, except returning nullptr on error. - std::unique_ptr TryEvaluate(HloInstruction* instruction); + // Same as Evaluate, except returning false on error and accepts an output + // pointer. + bool TryEvaluate(HloInstruction* instruction, Literal* result); // Evaluates a single HLO instruction, substituting the given literals for // some of the instruction's operands. // // For example, given instruction = op(A, B, C) and the map // {A = x, C = y}, this evaluates op(x, B, y). 
- StatusOr> EvaluateWithSubstitutions( + StatusOr EvaluateWithSubstitutions( const HloInstruction* instruction, const std::unordered_map& substitutions); - StatusOr> EvaluateElementwiseBinaryOp( - HloOpcode opcode, const Literal& lhs, const Literal& rhs); + StatusOr EvaluateElementwiseBinaryOp(HloOpcode opcode, + const Literal& lhs, + const Literal& rhs); - StatusOr> EvaluateElementwiseUnaryOp( - HloOpcode opcode, const Literal& operand); + StatusOr EvaluateElementwiseUnaryOp(HloOpcode opcode, + const Literal& operand); - StatusOr> EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, const Literal& lhs, - const Literal& rhs); + StatusOr EvaluateDotOp(const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + const Literal& lhs, const Literal& rhs); protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this @@ -196,7 +197,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { auto it = evaluated_.find(hlo); CHECK(it != evaluated_.end()) << "could not find evaluated value for: " << hlo->ToString(); - return *(it->second); + return it->second; } // Tracks the HLO instruction and its evaluated literal result. @@ -204,12 +205,13 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // that are no longer a parent for any other subsequent instruction in // post-orderring. // Must be cleared for each evaluation. - tensorflow::gtl::FlatMap> - evaluated_; + // Storing Literal in place require the container to have pointer stability so + // we cannot use FlatMap any more. 
+ std::unordered_map evaluated_; private: template - static StatusOr> ElementWiseUnaryOpImpl( + static StatusOr ElementWiseUnaryOpImpl( HloInstruction* instruction, const std::function& unary_op, const Literal& operand_literal) { @@ -226,9 +228,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { ShapeUtil::HumanString(operand->shape())); } - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 7e490d7f324022fdf02c569fc1986d0b6f5823ba..01e88566a55dbfddaaec5db6100327a8c1db398b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -52,15 +52,11 @@ static std::array use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface, public HloVerifiedTestBase { protected: - HloEvaluatorTest() - : HloVerifiedTestBase(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false), - use_bfloat16_(GetParam()) { + HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) { evaluator_ = absl::make_unique(); } - std::unique_ptr Evaluate( - absl::Span arg_literals = {}) { + Literal Evaluate(absl::Span arg_literals = {}) { if (use_bfloat16_) { // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. 
auto type_converter = HloElementTypeConverter(F32, BF16); @@ -72,39 +68,37 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, std::unique_ptr evaluator_; - void TestUnaryOp(HloOpcode opcode, std::unique_ptr expected, - std::unique_ptr input, float aabs = 0) { + void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input, + float aabs = 0) { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); - b.AddInstruction( - HloInstruction::CreateUnary(expected->shape(), opcode, c1)); + b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - auto element_type = expected->shape().element_type(); + auto element_type = expected.shape().element_type(); if (element_type == F32 || element_type == F64) { ErrorSpec error(aabs); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error)); } else { - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } } - void TestBinaryOp(HloOpcode opcode, std::unique_ptr expected, - std::unique_ptr lhs, - std::unique_ptr rhs) { + void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs, + Literal rhs) { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); b.AddInstruction( - HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2)); + HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } bool use_bfloat16_; @@ -120,7 
+114,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - Shape shape = low->shape(); + Shape shape = low.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); @@ -129,11 +123,11 @@ TEST_P(HloEvaluatorTest, DoesClamp) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { @@ -141,7 +135,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto value = LiteralUtil::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); auto high = LiteralUtil::CreateR0(1.f); - Shape shape = value->shape(); + Shape shape = value.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); @@ -150,11 +144,11 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{0, 0}, {1, 1}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select @@ -164,7 +158,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto 
on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); - Shape shape = on_true->shape(); + Shape shape = on_true.shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred))); auto c2 = @@ -175,11 +169,11 @@ TEST_P(HloEvaluatorTest, DoesSelect) { HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -298,7 +292,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); - std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; + std::vector args = {&lhs, &rhs, &rhs2}; Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); @@ -316,11 +310,11 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { lhs_instruction, param_rhs2)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(args); + Literal result = Evaluate(args); auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies Reshape operation is correctly evaluated. 
@@ -330,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { TF_ASSERT_OK_AND_ASSIGN(auto literal, LiteralUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = literal->CloneToUnique(); + auto literal_clone = literal.Clone(); HloInstruction* literal_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(literal))); @@ -340,14 +334,13 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - result->EachCell( - [&](absl::Span indices, NativeT value) { - std::vector rindexes = Permute(permutation, indices); - EXPECT_NEAR(value, literal_clone->Get(rindexes), 0.031250); - }); + result.EachCell([&](absl::Span indices, NativeT value) { + std::vector rindexes = Permute(permutation, indices); + EXPECT_NEAR(value, literal_clone.Get(rindexes), 0.031250); + }); } // Verifies Broadcast operation is correctly evaluated. @@ -359,12 +352,12 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { HloInstruction* literal_instruction = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( - output_literal->shape(), literal_instruction, {1, 2})); + output_literal.shape(), literal_instruction, {1, 2})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { @@ -377,13 +370,13 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { HloInstruction::CreateConstant(std::move(input_literal))); // Broadcast dimension should be empty in the case of scalars. 
b.AddInstruction(HloInstruction::CreateBroadcast( - output_literal->shape(), literal_instruction, + output_literal.shape(), literal_instruction, /*broadcast_dimensions=*/{})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate({}); + Literal result = Evaluate({}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal)); } TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { @@ -401,11 +394,11 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2( {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { @@ -423,10 +416,10 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({100, 200}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { @@ -435,17 +428,17 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { auto input_literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto expected = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); - ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), - expected->shape())); + ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(), + expected.shape())); HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); - b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); + 
b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { @@ -455,17 +448,17 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1})); auto expected = LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0})); - ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), - expected->shape())); + ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(), + expected.shape())); HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); - b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); + b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } PaddingConfig CreatePaddingConfig( @@ -498,12 +491,12 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { shape, operand_instruction, padding_value_instruction, padding_config)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { @@ -525,7 +518,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { shape, input_instruction, pad_instruction, 
r4_padding_on_dim0_dim1)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = absl::make_unique>(8, 5, 1, 1); expected_array->Fill(kPadValue); @@ -538,7 +531,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { auto expected = LiteralUtil::CreateR4FromArray4D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, NegativePadding2D) { @@ -569,7 +562,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } auto expected_array = absl::make_unique>(1, 5); @@ -580,7 +573,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -614,12 +607,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = absl::make_unique>(0, 9); auto expected = LiteralUtil::CreateR2FromArray2D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank1) { @@ -649,10 +642,11 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); 
module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off auto expected_array = Array2D({ @@ -664,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // clang-format on auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank1AndRank2) { @@ -694,14 +688,15 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({22.f, 28.f}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank2) { @@ -737,10 +732,11 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected_array = Array2D({ {22.f, 28.f}, @@ -750,7 +746,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { }); auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SimpleConv1D) { @@ -788,17 +784,18 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { dnums.set_kernel_input_feature_dimension(1); 
dnums.add_kernel_spatial_dimensions(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { @@ -842,12 +839,13 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 4, 4); // clang-format off @@ -860,7 +858,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { // clang-format on auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { @@ -925,12 +923,13 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(1); - const Shape& shape = 
ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] @@ -940,7 +939,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { @@ -1002,12 +1001,13 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(1); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] @@ -1017,7 +1017,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { auto expected = LiteralUtil::CreateR4FromArray4D( use_bfloat16_ ? 
expected_array_bf16 : expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { @@ -1061,12 +1061,13 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 7, 7); expected_array.FillWithYX(Array2D({ @@ -1080,7 +1081,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { @@ -1124,12 +1125,13 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 8, 8); expected_array.FillWithYX(Array2D({ @@ -1144,7 +1146,7 @@ 
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, @@ -1195,12 +1197,13 @@ TEST_P(HloEvaluatorTest, ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); Array4D expected_array(1, 1, 9, 3); expected_array.FillWithYX(Array2D({ @@ -1216,7 +1219,68 @@ TEST_P(HloEvaluatorTest, })); auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { + HloComputation::Builder b(TestName()); + std::vector input_dims = {1, 2, 2, 4}; + std::vector filter_dims = {2, 2, 2, 8}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + // Tensorflow dimension numbers for 2D convolution. 
+ ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + std::iota(input_elems.begin(), input_elems.end(), -7); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4))); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + std::iota(filter_elems.begin(), filter_elems.end(), -31); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4))); + + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, + /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); + module().AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + Array4D expected_array(1, 1, 1, 8); + expected_array.FillWithYX( + Array2D({{668, 664, 660, 656, 668, 680, 692, 704}})); + auto expected = 
LiteralUtil::CreateR4FromArray4D(expected_array); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; @@ -1249,9 +1313,8 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) { module().AddEntryComputation(b.Build()); HloEvaluator hlo_eval; - std::unique_ptr result = - hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); - LiteralTestUtil::ExpectR0Equal(kNumElements, *result); + Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie(); + LiteralTestUtil::ExpectR0Equal(kNumElements, result); } // Reducing many numbers should be fast because it doesn't create @@ -1328,11 +1391,11 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR1({6, 18}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowMax) { @@ -1380,10 +1443,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{6, 7}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd) { @@ -1437,10 +1500,10 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({{1, 3, 5}, {5, 11, 13}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { @@ -1448,7 +1511,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { // arg: f32[4,4,4,4,4,4] full of ones. 
Using small dims to limit run-time. std::vector input_dims(6, 4); - std::unique_ptr arg_literal = + Literal arg_literal = LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); HloInstruction* arg_instruction = @@ -1498,12 +1561,12 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); std::vector output_dims = {4, 3, 3, 3, 4, 4}; - std::unique_ptr result_literal = + Literal result_literal = LiteralUtil::CreateFullWithDescendingLayout(output_dims, 8.0f); - EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result)); } TEST_P(HloEvaluatorTest, StridedSlice) { @@ -1530,14 +1593,14 @@ TEST_P(HloEvaluatorTest, StridedSlice) { /*strides=*/{2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {3}, {19}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, DynamicSlice) { @@ -1564,14 +1627,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { start_indices, {2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } // Verifies that the HloEvaluator's implementation goes along with existing @@ -1600,14 +1663,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { start_indices, {2, 3})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {2, 3, 4}, {6, 7, 8}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, 
result)); } TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { @@ -1637,14 +1700,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { shape, operand, update, start_indices)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {1, -2, -3}, {5, -6, -7}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SetAndGetTuples) { @@ -1673,14 +1736,14 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto expected = LiteralUtil::CreateR2({ {1, 2, 3}, {5, 6, 7}, }); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { @@ -1712,16 +1775,14 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); auto result_inner_literal = LiteralUtil::CreateR2FromArray2D(*operand_array); - auto expected = LiteralUtil::MakeTuple({ - result_inner_literal.get(), - result_inner_literal.get(), - }); + auto expected = + LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, Reverse) { @@ -1752,7 +1813,7 @@ TEST_P(HloEvaluatorTest, Reverse) { b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); module().AddEntryComputation(b.Build()); - std::unique_ptr result = Evaluate(); + Literal result = Evaluate(); // clang-format off auto expected = LiteralUtil::CreateR4FromArray4D({ @@ -1774,7 +1835,7 @@ TEST_P(HloEvaluatorTest, Reverse) { }); // clang-format on - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); + 
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { @@ -1790,12 +1851,13 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}. HloEvaluator evaluator; + Literal param0_literal = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); auto result = evaluator.EvaluateWithSubstitutions( - add, {{param0, LiteralUtil::CreateR1({1, 2, 3, 4}).get()}, - {square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); + add, {{param0, ¶m0_literal}, {square, &square_literal}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1815,11 +1877,12 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { // Evaluate add with square = {10, 20, 30, 40}. 
HloEvaluator evaluator; - auto result = evaluator.EvaluateWithSubstitutions( - add, {{square, LiteralUtil::CreateR1({10, 20, 30, 40}).get()}}); + Literal square_literal = LiteralUtil::CreateR1({10, 20, 30, 40}); + auto result = + evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}}); TF_ASSERT_OK(result.status()); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); + LiteralUtil::CreateR1({11, 22, 33, 44}), result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1838,12 +1901,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralUtil::CreateR2({{1, 2, 3}, {7, 8, 9}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1862,12 +1925,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralUtil::CreateR2({{1, 3}, {4, 6}, {7, 9}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1886,14 +1949,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - 
LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3( + LiteralUtil::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), start_indices.get()}))); + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1912,15 +1974,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, 1}, {-4, 4}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, @@ -1940,15 +2001,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-2, 2}, {-1, 1}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -1967,12 +2027,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); 
- EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{5}}), - *Evaluate({operand.get(), start_indices.get()}))); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{5}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -1991,13 +2050,12 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{8}}, {{5}}}), - *Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR3({{{8}}, {{5}}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2016,11 +2074,10 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{}, {}}), - *Evaluate({operand.get(), start_indices.get()}))); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2({{}, {}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { @@ -2040,12 +2097,12 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr start_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal start_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR2({{0, 1}, {2, 1}}), - 
*Evaluate({operand.get(), start_indices.get()}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{0, 1}, {2, 1}}), + Evaluate({&operand, &start_indices}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) { @@ -2070,15 +2127,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) { @@ -2103,15 +2158,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) { @@ -2137,15 +2191,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 
6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) { @@ -2171,15 +2223,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) { @@ -2205,17 +2255,15 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2( + Literal operand = LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({2, 1}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); + Literal updates = LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); EXPECT_TRUE(LiteralTestUtil::Near( - 
*LiteralUtil::CreateR2( + LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}), - ErrorSpec{0.1, 0.01})); + Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01})); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) { @@ -2241,15 +2289,13 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}), + Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) { @@ -2275,15 +2321,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal updates = LiteralUtil::CreateR3( {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + LiteralUtil::CreateR2({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}), + Evaluate({&operand, &scatter_indices, &updates}))); } 
TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) { @@ -2308,21 +2353,18 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + Literal expected = LiteralUtil::CreateR3({{{-10, 10}, {-2, 2}, {-3, 3}}, // {{-40, 40}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, @@ -2348,21 +2390,18 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + Literal expected = LiteralUtil::CreateR3({{{-20, 20}, {-10, 10}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) { @@ -2387,16 +2426,14 @@ ENTRY main { } )"; 
ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10}}); + Literal expected = LiteralUtil::CreateR2({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) { @@ -2421,17 +2458,14 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - std::unique_ptr updates = - LiteralUtil::CreateR3({{{10}}, {{20}}}); - std::unique_ptr expected = + Literal scatter_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal updates = LiteralUtil::CreateR3({{{10}}, {{20}}}); + Literal expected = LiteralUtil::CreateR2({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) { @@ -2456,13 +2490,11 @@ ENTRY main { } )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{}, {}}); 
EXPECT_TRUE(LiteralTestUtil::Equal( - *operand, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + operand, Evaluate({&operand, &scatter_indices, &updates}))); } TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) { @@ -2489,16 +2521,13 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr scatter_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal scatter_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20}, {30, 40}}); - std::unique_ptr expected = - LiteralUtil::CreateR1({10, 61, 32}); + Literal updates = LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + Literal expected = LiteralUtil::CreateR1({10, 61, 32}); EXPECT_TRUE(LiteralTestUtil::Equal( - *expected, - *Evaluate({operand.get(), scatter_indices.get(), updates.get()}))); + expected, Evaluate({&operand, &scatter_indices, &updates}))); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -2535,11 +2564,29 @@ ENTRY main { )"; ParseAndVerifyModule(hlo_text); - std::unique_ptr arg = LiteralUtil::CreateR1( + Literal arg = LiteralUtil::CreateR1( {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)}); - std::unique_ptr expected = - LiteralUtil::CreateR0(bfloat16(44.0f)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()}))); + Literal expected = LiteralUtil::CreateR0(bfloat16(44.0f)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg}))); +} + +TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) { + // Regression test for b/114735354. 
+ const string hlo_text = R"( +HloModule SliceWithDifferentLayout + +ENTRY main { + arg = f32[2,2,2]{0,1,2} parameter(0) + ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]} +} +)"; + ParseAndVerifyModule(hlo_text); + + Literal arg = LiteralUtil::CreateR3WithLayout( + {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, + LayoutUtil::MakeLayout({0, 1, 2})); + Literal actual = Evaluate({&arg}); + EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual)); } INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index cb27e13e99c0192a9796d3d32eba2637e7db06bc..8fb17a00330deae8c004a8d491b46bf7adb84241 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -246,32 +246,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, + TF_ASSIGN_OR_RETURN(Literal result, parent_->GetEvaluatedLiteralFor(operand).Convert( convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } Status HandleBitcastConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, + TF_ASSIGN_OR_RETURN(Literal result, 
parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( convert->shape().element_type())); - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } + parent_->evaluated_[convert] = std::move(result); return Status::OK(); } @@ -978,10 +967,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = absl::make_unique(result_shape); + Literal result(result_shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span out_index) { + result.Populate([&](absl::Span out_index) { std::vector from_index(out_index.begin(), out_index.end()); for (const int64 dim : reverse_dimensions) { from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; @@ -1021,9 +1010,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_EQ(num_spatial_dims + 2, lhs_rank); CHECK_EQ(num_spatial_dims + 2, rhs_rank); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, - window, dnums)); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, conv->feature_group_count(), window, dnums)); CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " @@ -1046,9 +1036,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto lhs_literal_data = lhs_literal.data(); auto rhs_literal_data = rhs_literal.data(); + int64 feature_group_count = conv->feature_group_count(); + auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data](absl::Span 
out_index) { + rhs_literal_data, + feature_group_count](const absl::Span out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1059,7 +1052,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 output_batch_dim = dnums.output_batch_dimension(); const int64 output_z_dim = dnums.output_feature_dimension(); - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + const int64 input_z_size = + ShapeUtil::GetDimension(lhs_shape, input_z_dim); + // The size of an input feature group. + const int64 input_feature_group_size = input_z_size / feature_group_count; + + const int64 output_z_size = + ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim); + // The output feature dimension is a concatenation of convolution results + // from the different groups. + const int64 output_feature_group_size = + output_z_size / feature_group_count; + + // Calculate the group index to which the current output index + // belongs. + const int64 feature_group_index = + out_index[output_z_dim] / output_feature_group_size; ElementwiseT result_val = static_cast(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), @@ -1067,7 +1075,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Convolve input feature with kernel. 
do { - for (int64 iz = 0; iz < z_size; ++iz) { + for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) { + const int64 iz = + feature_group_index * input_feature_group_size + rhs_iz; + int64 lhs_linear_index = 0; lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; @@ -1076,7 +1087,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { int64 rhs_linear_index = 0; rhs_linear_index += out_index[output_z_dim] * rhs_dim_multipliers[kernel_output_z_dim]; - rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; + rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; // Find corresponding spatial dimension index for input (lhs). for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { @@ -1135,8 +1146,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast(result_val); }; - auto result = absl::make_unique(result_shape); - TF_RETURN_IF_ERROR(result->PopulateParallel(func)); + Literal result(result_shape); + TF_RETURN_IF_ERROR(result.PopulateParallel(func)); parent_->evaluated_[conv] = std::move(result); return Status::OK(); @@ -1209,9 +1220,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } } - auto result = absl::make_unique(dot->shape()); + Literal result(dot->shape()); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span result_index) { + result.Populate([&](absl::Span result_index) { ElementwiseT result_val = static_cast(0); for (int64 i = 0; i < result_index.size(); i++) { @@ -1258,8 +1269,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Create new HLO of padded shape with padding value. 
ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = absl::make_unique(pad->shape()); - TF_RETURN_IF_ERROR(result->Populate( + Literal result(pad->shape()); + TF_RETURN_IF_ERROR(result.Populate( [&scalar](absl::Span multi_index) { return scalar; })); const Literal& evaluated_operand = @@ -1267,7 +1278,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), 0); - std::vector target_index(ShapeUtil::Rank(result->shape()), 0); + std::vector target_index(ShapeUtil::Rank(result.shape()), 0); // Loop through each element of the operand, assign them to the // corresponding index of the resulting padded literal. @@ -1289,8 +1300,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return true; } } - result->Set(target_index, - evaluated_operand.Get(input_index)); + result.Set(target_index, + evaluated_operand.Get(input_index)); return true; }; @@ -1417,16 +1428,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> MapImpl(HloInstruction* map) { + StatusOr MapImpl(HloInstruction* map) { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = absl::make_unique(map->shape()); + Literal result(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { - std::vector> arg_literals; + result.Populate([&](absl::Span multi_index) { + std::vector arg_literals; arg_literals.reserve(operands.size()); // Construct scalar literal parameters to be passed to the map @@ -1441,16 +1452,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { arg_literals.push_back(std::move(curr_val_literal)); } - std::unique_ptr computed_result = - embedded_evaluator - .Evaluate>(*computation, - arg_literals) + Literal computed_result = + embedded_evaluator.Evaluate(*computation, 
arg_literals) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on // the same computation. embedded_evaluator.ResetVisitStates(); - return computed_result->Get({}); + return computed_result.Get({}); })); return std::move(result); } @@ -1535,9 +1544,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](const ReturnT& a, const ReturnT& b) { return SafeLess(a, b); }); - auto result_literal = absl::make_unique(keys_literal.shape()); - result_literal->PopulateR1(absl::Span(result_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + Literal result_literal(keys_literal.shape()); + result_literal.PopulateR1(absl::Span(result_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal.ToString(); return result_literal; }; @@ -1546,16 +1555,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. 
- auto result_literal = absl::make_unique(keys_literal.shape()); + Literal result_literal(keys_literal.shape()); int64 r1_length = keys->shape().dimensions(1); for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto r1_slice, keys_literal.Slice({row, 0}, {row + 1, r1_length}) - ->Reshape({r1_length})); - auto r1_result = sort_r1(*r1_slice); - TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( - *r1_result, {0, 0}, {row, 0}, {1, r1_length})); + .Reshape({r1_length})); + auto r1_result = sort_r1(r1_slice); + TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length})); + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( + r1_result, {0, 0}, {row, 0}, {1, r1_length})); } parent_->evaluated_[sort] = std::move(result_literal); } @@ -1629,9 +1638,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - absl::InlinedVector, 1> results(num_args); + absl::InlinedVector results(num_args); for (int64 i = 0; i < num_args; ++i) { - results[i] = absl::make_unique(result_shape); + results[i] = Literal(result_shape); } Status eval_status; @@ -1645,7 +1654,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } for (int64 input = 0; input < num_args; ++input) { - TF_RETURN_IF_ERROR(results[input]->Populate( + TF_RETURN_IF_ERROR(results[input].Populate( [&](absl::Span multi_index) { if (!eval_status.ok()) { return init_scalars[input]; @@ -1681,8 +1690,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } // Evaluate computation with specified literal operands. 
- absl::InlinedVector, 1> - embedded_operands; + absl::InlinedVector embedded_operands; for (ReturnT value : result_values) { embedded_operands.push_back( LiteralUtil::CreateR0(value)); @@ -1695,11 +1703,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { embedded_operands.size()); std::transform(embedded_operands.begin(), embedded_operands.end(), embedded_operands_ptrs.begin(), - [](const std::unique_ptr& ptr) { - return ptr.get(); - }); + [](Literal& literal) { return &literal; }); - TF_ASSIGN_OR_RETURN(std::unique_ptr computed_result, + TF_ASSIGN_OR_RETURN(Literal computed_result, embedded_evaluator.Evaluate( *function, embedded_operands_ptrs)); // Clear visit states so that we can use the evaluator again on @@ -1707,10 +1713,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { embedded_evaluator.ResetVisitStates(); // Assign computed result to result_val. if (!has_tuple_output) { - result_values[0] = computed_result->Get({}); + result_values[0] = computed_result.Get({}); } else { for (int64 i = 0; i < num_args; ++i) { - result_values[i] = computed_result->Get( + result_values[i] = computed_result.Get( /*multi_index=*/{}, /*shape_index=*/{i}); } } @@ -1726,9 +1732,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (!has_tuple_output) { parent_->evaluated_[reduce] = std::move(results[0]); } else { - auto tuple_result = absl::make_unique(reduce->shape()); + Literal tuple_result(reduce->shape()); for (int64 i = 0; i < num_args; ++i) { - TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i})); + TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i})); } parent_->evaluated_[reduce] = std::move(tuple_result); } @@ -1759,10 +1765,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get({}); - auto result = absl::make_unique(select_and_scatter->shape()); + Literal 
result(select_and_scatter->shape()); // Initialize result array with the init value. - TF_RETURN_IF_ERROR(result->Populate( + TF_RETURN_IF_ERROR(result.Populate( [&](absl::Span output_index) { return init_scalar; })); std::vector window_dimension_sizes; @@ -1812,15 +1818,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val = curr_val; selected_index = operand_index; } - curr_val_literal->Set({}, curr_val); - selected_val_literal->Set({}, *selected_val); - std::unique_ptr computed_result = + curr_val_literal.Set({}, curr_val); + selected_val_literal.Set({}, *selected_val); + Literal computed_result = embedded_evaluator .Evaluate( - *select, - {selected_val_literal.get(), curr_val_literal.get()}) + *select, {&selected_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); - bool selected = !computed_result->Get({}); + bool selected = !computed_result.Get({}); if (selected) { selected_val = curr_val; selected_index = operand_index; @@ -1834,16 +1839,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (std::equal(operand_index.begin(), operand_index.end(), selected_index->begin())) { auto source = source_literal.Get(source_index); - auto scattered = result->Get(operand_index); - source_literal_scatter->Set({}, source); - scattered_literal->Set({}, scattered); - std::unique_ptr computed_result = + auto scattered = result.Get(operand_index); + source_literal_scatter.Set({}, source); + scattered_literal.Set({}, scattered); + Literal computed_result = embedded_evaluator - .Evaluate(*scatter, - {source_literal_scatter.get(), - scattered_literal.get()}) + .Evaluate( + *scatter, + {&source_literal_scatter, &scattered_literal}) .ConsumeValueOrDie(); - result->Set(operand_index, computed_result->Get({})); + result.Set(operand_index, computed_result.Get({})); // Clear visit states so that the we can use the evaluator again // on the same computation. 
embedded_evaluator.ResetVisitStates(); @@ -1894,10 +1899,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = absl::make_unique(reduce_window->shape()); + Literal result(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span output_index) { + result.Populate([&](absl::Span output_index) { ReturnT result_val = init_scalar; std::fill(window_index.begin(), window_index.end(), 0); @@ -1913,18 +1918,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { LiteralUtil::CreateR0(curr_val); const auto result_val_literal = LiteralUtil::CreateR0(result_val); - std::unique_ptr computed_result = + Literal computed_result = embedded_evaluator .Evaluate( - *function, - {result_val_literal.get(), curr_val_literal.get()}) + *function, {&result_val_literal, &curr_val_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again // on the same computation. embedded_evaluator.ResetVisitStates(); - result_val = computed_result->Get({}); + result_val = computed_result.Get({}); }); return result_val; @@ -1939,7 +1943,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // literal (if there is one) to `reshaped_indices`. 
StatusOr> ReshapedScatterIndices( int64 index_vector_dim, const Literal& indices, - std::unique_ptr* reshaped_indices) { + Literal* reshaped_indices) { if (indices.shape().dimensions_size() != index_vector_dim) { return std::cref(indices); } @@ -1948,7 +1952,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { indices.shape().dimensions().end()); new_shape.push_back(1); TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape)); - return std::cref(**reshaped_indices); + return std::cref(*reshaped_indices); } // Returns an ShapeUtil::IndexIterationSpace that iterates over the update @@ -2208,7 +2212,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { scatter->scatter_dimension_numbers(); const Literal& operand = parent_->GetEvaluatedLiteralFor(scatter->operand(0)); - std::unique_ptr reshaped_scatter_indices; + Literal reshaped_scatter_indices; TF_ASSIGN_OR_RETURN(const Literal& scatter_indices, ReshapedScatterIndices(dim_numbers.index_vector_dim(), parent_->GetEvaluatedLiteralFor( @@ -2238,7 +2242,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Initialize the result with the operand. This makes it easier to handle // the updates even when the indices are repeated. 
- std::unique_ptr result = operand.CloneToUnique(); + Literal result = operand.Clone(); HloEvaluator embedded_evaluator; auto scatter_inner_loop_body = [&](absl::Span update_window_index, @@ -2277,19 +2281,19 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } auto result_value_literal = - LiteralUtil::CreateR0(result->Get(input_index)); + LiteralUtil::CreateR0(result.Get(input_index)); auto update_value_literal = LiteralUtil::CreateR0(updates.Get(update_index)); - std::unique_ptr updated_result = + Literal updated_result = embedded_evaluator .Evaluate( *scatter->to_apply(), - {result_value_literal.get(), update_value_literal.get()}) + {&result_value_literal, &update_value_literal}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again on the // same computation. embedded_evaluator.ResetVisitStates(); - result->Set(input_index, updated_result->Get({})); + result.Set(input_index, updated_result.Get({})); return true; }; @@ -2337,9 +2341,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return operand_literal.Get(operand_index); }; - auto result = LiteralUtil::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - TF_RETURN_IF_ERROR(result->Populate(func)); + Literal result(shape); + TF_RETURN_IF_ERROR(result.Populate(func)); parent_->evaluated_[slice] = std::move(result); return Status::OK(); } @@ -2553,7 +2556,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { if (ShapeUtil::Rank(iota->shape()) > 1) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[iota], - result->Broadcast(iota->shape(), {iota->iota_dimension()})); + result.Broadcast(iota->shape(), {iota->iota_dimension()})); } else { TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); parent_->evaluated_[iota] = std::move(result); @@ -2623,9 +2626,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> DynamicSlice( - const Literal& operand_literal, const 
Literal& start_indices_literal, - const Shape& result_shape) { + StatusOr DynamicSlice(const Literal& operand_literal, + const Literal& start_indices_literal, + const Shape& result_shape) { auto start_indices_typed = start_indices_literal.data(); std::vector start(start_indices_typed.begin(), start_indices_typed.end()); @@ -2638,9 +2641,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector operand_indices(start.size()); - auto result = absl::make_unique(result_shape); + Literal result(result_shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { CHECK_GE(multi_index[i] + start[i], 0); operand_indices[i] = multi_index[i] + start[i]; @@ -2654,12 +2657,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> DynamicUpdateSlice( - const Literal& operand_literal, const Literal& update_literal, - const Literal& start_indices_literal) { - auto result = operand_literal.CloneToUnique(); + StatusOr DynamicUpdateSlice(const Literal& operand_literal, + const Literal& update_literal, + const Literal& start_indices_literal) { + auto result = operand_literal.Clone(); auto start_indices_typed = start_indices_literal.data(); - const auto rank = ShapeUtil::Rank(result->shape()); + const auto rank = ShapeUtil::Rank(result.shape()); std::vector start(start_indices_typed.begin(), start_indices_typed.end()); // Clamp the update start indices so the slice is in-bounds w.r.t the @@ -2667,15 +2670,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { for (int64 i = 0; i < rank; ++i) { start[i] = std::min( std::max(0, start[i]), - result->shape().dimensions(i) - update_literal.shape().dimensions(i)); + result.shape().dimensions(i) - update_literal.shape().dimensions(i)); } std::vector result_index(rank, 0); auto func = [&](absl::Span update_index) { 
std::transform(update_index.begin(), update_index.end(), start.begin(), result_index.begin(), std::plus()); - result->Set(result_index, - update_literal.Get(update_index)); + result.Set(result_index, + update_literal.Get(update_index)); return true; }; @@ -2688,7 +2691,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result); } - StatusOr> ElementWiseUnaryOp( + StatusOr ElementWiseUnaryOp( HloInstruction* instruction, const std::function& unary_op) { const Literal& operand_literal = @@ -2701,7 +2704,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return std::move(result_literal); } - StatusOr> ElementWiseBinaryOp( + StatusOr ElementWiseBinaryOp( HloInstruction* instruction, const std::function& binary_op) { @@ -2723,10 +2726,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return ConvertBinaryFunction(binary_op)( lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -2735,7 +2738,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template - StatusOr> ElementwiseTernaryOp( + StatusOr ElementwiseTernaryOp( HloInstruction* instruction, const std::function& ternary_op) { const auto shape = instruction->shape(); @@ -2760,10 +2763,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = absl::make_unique(shape); + Literal result(shape); TF_RETURN_IF_ERROR( - result->Populate([&](absl::Span multi_index) { + result.Populate([&](absl::Span multi_index) { return 
ternary_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index), ehs_literal.Get(multi_index)); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 3041d94fa9f55b1acffc1295d07e48c967322865..287ba84b3b24d3ec6dc21d157205ebc6a987c7d7 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -120,12 +120,23 @@ class NodeFilter { std::function filter_; }; +// We arbitrarily set this as the boundary between "large" and "small" +// instructions. +bool IsSmall(const HloInstruction* instr) { + if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) || + ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) { + return true; + } + return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096; +} + // Node color schemes, used by NodeColorAttributes. enum ColorScheme { kBlue, kBrown, kDarkBlue, kDarkGreen, + kDarkOrange, kDarkRed, kGray, kGreen, @@ -158,6 +169,10 @@ NodeColors NodeColorsForScheme(ColorScheme color) { return NodeColors{"filled", "#1565c0", "#003c8f", "white"}; case kDarkGreen: return NodeColors{"filled", "#2e7d32", "#005005", "white"}; + case kDarkOrange: + // This is more of a "medium" orange, made to look close to kOrange; + // there's probably room for a darker weight if desired. + return NodeColors{"filled", "#ffb74d", "#c88719", "black"}; case kDarkRed: return NodeColors{"filled", "#b71c1c", "#7f0000", "white"}; case kGray: @@ -454,9 +469,8 @@ stylesheet=< string graph_label = StrCat(label_, "
Computation ", computation_->name()); if (computation_->IsFusionComputation()) { - StrAppend(&graph_label, - StrCat(" (in fusion instruction ", - computation_->FusionInstruction()->name(), ")")); + StrAppend(&graph_label, " (in fusion instruction ", + computation_->FusionInstruction()->name(), ")"); } if (profile_ != nullptr) { auto cycles = profile_->total_cycles_executed(*computation_); @@ -893,7 +907,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { sharding_colors_.emplace(instr->sharding(), color); return color; } - const auto kParameterColor = kOrange; + + // Choose different weights of orange for small vs large parameters. This + // distinction is often important, especially in fusion nodes. + auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange; // Special case: If this instruction has a parameter merged into it, paint it // the same color as a parameter. Unless the merged-in parameter is a @@ -905,7 +922,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { ShouldMergeIntoUsers(operand) && TryGetFusionParameterConstant(operand) == nullptr; })) { - return kParameterColor; + return parameter_color; } // Pick different colors or shapes for instructions which are particularly @@ -1015,7 +1032,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReducePrecision: return kRed; case HloOpcode::kParameter: - return kParameterColor; + return parameter_color; case HloOpcode::kBatchNormGrad: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormTraining: @@ -1160,20 +1177,6 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { return StrJoin(lines, "
"); } -// Gets the total number of array elements in the given shape. For tuples, this -// is the sum of all the sizes of all of the array elements recursively in the -// tuple. -static int64 TotalElementsInShape(const Shape& shape) { - int64 elems = 0; - ShapeUtil::ForEachSubshape( - shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(subshape)) { - elems += ShapeUtil::ElementsIn(subshape); - } - }); - return elems; -} - void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, int64 operand_num, bool control_edge = false) { @@ -1196,14 +1199,11 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { } // We print "small" arrays using a hollow arrowhead and "large" arrays using - // a filled arrowhead. For now, we use an arbitrary cutoff for what "big" - // means. - bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; - + // a filled arrowhead. constexpr char kEdgeFmt[] = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to), - (is_big_array ? "normal" : "empty"), + (IsSmall(from) ? 
"empty" : "normal"), from->name(), to->name(), edge_label)); }; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 6d13f85cbbca2ae4b2a794ca5de975fe21e8212e..e905f2983a43189eeb06824cf3078c235ab07925 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -250,7 +250,7 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.has_literal()); TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); - instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); + instruction = CreateTrace(literal.GetR1U8AsString(), operands(0)); break; } case HloOpcode::kFusion: { @@ -341,17 +341,21 @@ StatusOr> HloInstruction::CreateFromProto( source_target_pairs); break; } - case HloOpcode::kConvolution: + case HloOpcode::kConvolution: { TF_RET_CHECK(proto.operand_ids_size() == 2) << "Convolution instruction should have 2 operands but sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); + PrecisionConfig precision_config = proto.precision_config(); + precision_config.mutable_operand_precision()->Resize( + proto.operand_ids_size(), PrecisionConfig::DEFAULT); instruction = CreateConvolve( - proto.shape(), operands(0), operands(1), proto.window(), - proto.convolution_dimension_numbers(), - std::max(static_cast(proto.feature_group_count()), 1LL)); + proto.shape(), operands(0), operands(1), + std::max(proto.feature_group_count(), 1), proto.window(), + proto.convolution_dimension_numbers(), precision_config); break; + } case HloOpcode::kReduceWindow: TF_RET_CHECK(proto.operand_ids_size() == 2) << "ReduceWindow instruction should have 2 operands but sees " @@ -447,6 +451,28 @@ StatusOr> HloInstruction::CreateFromProto( << proto.dimensions_size(); instruction = CreateIota(proto.shape(), proto.dimensions(0)); break; + case HloOpcode::kDot: { + 
TF_RET_CHECK(proto.has_dot_dimension_numbers()) + << "Dot instruction should have dot_dimension_numbers."; + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Dot instruction should have 2 operands but sees " + << proto.operand_ids_size(); + PrecisionConfig precision_config = proto.precision_config(); + precision_config.mutable_operand_precision()->Resize( + proto.operand_ids_size(), PrecisionConfig::DEFAULT); + instruction = absl::make_unique( + proto.shape(), operands(0), operands(1), + proto.dot_dimension_numbers(), precision_config); + break; + } + case HloOpcode::kDomain: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Domain instruction should have 1 operands but sees " + << proto.operand_ids_size(); + instruction = absl::make_unique( + proto.shape(), operands(0), /*operand_side_metadata=*/nullptr, + /*user_side_metadata=*/nullptr); + break; default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -468,6 +494,9 @@ StatusOr> HloInstruction::CreateFromProto( computation_map.at(computation_id)); } } + TF_RET_CHECK(!proto.has_precision_config()) + << instruction->opcode() << proto.DebugString(); + TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode(); break; } } @@ -476,12 +505,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); - instruction->precision_config_ = proto.precision_config(); - - if (proto.has_dot_dimension_numbers()) { - instruction->dot_dimension_numbers_ = - absl::make_unique(proto.dot_dimension_numbers()); - } + instruction->unique_id_ = proto.id(); if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, @@ -504,7 +528,7 @@ StatusOr> HloInstruction::CreateFromProto( } /* static */ std::unique_ptr HloInstruction::CreateConstant( - std::unique_ptr literal) { + Literal literal) { return 
absl::make_unique(std::move(literal)); } @@ -552,7 +576,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kClz: - case HloOpcode::kDomain: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -584,7 +607,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kAtan2: case HloOpcode::kDivide: case HloOpcode::kComplex: - case HloOpcode::kDot: case HloOpcode::kEq: case HloOpcode::kGe: case HloOpcode::kGt: @@ -643,10 +665,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) { + int64 feature_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) { return absl::make_unique( - shape, lhs, rhs, window, dimension_numbers, feature_group_count); + shape, lhs, rhs, feature_group_count, window, dimension_numbers, + precision_config); } /* static */ std::unique_ptr HloInstruction::CreateFft( @@ -658,30 +682,10 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dimension_numbers) { - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = - absl::make_unique(dimension_numbers); - return instruction; -} - -/* static */ std::unique_ptr HloInstruction::CreateCanonicalDot( - const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { - CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); - CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); - - auto instruction = - 
absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = - absl::make_unique(); - instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); - instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); - return instruction; + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) { + return absl::make_unique( + shape, lhs, rhs, dimension_numbers, precision_config); } /* static */ std::unique_ptr @@ -1057,7 +1061,6 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->clear_sharding(); } derived_instruction->set_metadata(metadata_); - derived_instruction->set_precision_config(precision_config_); } bool HloInstruction::HasSideEffectNoRecurse() const { @@ -1142,12 +1145,9 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, std::unique_ptr operand_side_metadata, std::unique_ptr user_side_metadata) { - auto instruction = - absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); - instruction->operand_side_metadata_ = std::move(operand_side_metadata); - instruction->user_side_metadata_ = std::move(user_side_metadata); - instruction->AppendOperand(operand); - return instruction; + return absl::make_unique( + shape, operand, std::move(operand_side_metadata), + std::move(user_side_metadata)); } std::unique_ptr HloInstruction::CloneWithNewOperands( @@ -1203,6 +1203,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kGather: case HloOpcode::kScatter: case HloOpcode::kIota: + case HloOpcode::kDot: + case HloOpcode::kDomain: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. 
@@ -1275,11 +1277,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kDot: - CHECK_EQ(new_operands.size(), 2); - clone = CreateDot(shape, new_operands[0], new_operands[1], - *dot_dimension_numbers_); - break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); @@ -1304,12 +1301,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kDomain: - CHECK_EQ(new_operands.size(), 1); - clone = - CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), - user_side_metadata_->Clone()); - break; case HloOpcode::kAfterAll: if (new_operands.empty()) { clone = CreateToken(); @@ -1605,11 +1596,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kAfterAll: return false; - // Check dot dimension numbers. - case HloOpcode::kDot: - return protobuf_util::ProtobufEquals(dot_dimension_numbers(), - other.dot_dimension_numbers()); - // Remaining instructions with special values. case HloOpcode::kCall: return eq_computations(to_apply(), other.to_apply()); @@ -1625,10 +1611,6 @@ bool HloInstruction::IdenticalSlowPath( return false; } - case HloOpcode::kDomain: - return operand_side_metadata().Matches(other.operand_side_metadata()) && - user_side_metadata().Matches(other.user_side_metadata()); - // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. 
case HloOpcode::kBatchNormTraining: @@ -1668,6 +1650,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kDynamicSlice: case HloOpcode::kGather: case HloOpcode::kScatter: + case HloOpcode::kDot: + case HloOpcode::kDomain: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -2037,15 +2021,6 @@ std::vector HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { std::vector extra = ExtraAttributesToStringImpl(options); - if (dot_dimension_numbers_ != nullptr) { - extra.push_back(DotDimensionNumbersToString()); - } - - string precision_config_string = PrecisionConfigToString(); - if (!precision_config_string.empty()) { - extra.push_back(precision_config_string); - } - if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { @@ -2122,7 +2097,7 @@ std::vector HloInstruction::ExtraAttributesToString( if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } - if (!control_predecessors_.empty()) { + if (options.print_control_dependencies() && !control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", StrJoin(control_predecessors_, ", ", [&](string* out, HloInstruction* pre) { @@ -2131,11 +2106,6 @@ std::vector HloInstruction::ExtraAttributesToString( }), "}")); } - if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { - extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), - "\", entry=", user_side_metadata_->ToString(), - ", exit=", operand_side_metadata_->ToString(), "}")); - } return extra; } @@ -2167,17 +2137,12 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; proto.set_backend_config(backend_config_); - *proto.mutable_precision_config() = precision_config_; if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { 
proto.add_called_computation_ids(computation->unique_id()); } } - if (dot_dimension_numbers_ != nullptr) { - *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; - } - if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); } @@ -2871,8 +2836,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) { return absl::AsciiStrToLower(RandomDistribution_Name(distribution)); } -string PrecisionToString(const PrecisionConfigProto::Precision& precision) { - return absl::AsciiStrToLower(PrecisionConfigProto::Precision_Name(precision)); +string PrecisionToString(const PrecisionConfig::Precision& precision) { + return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision)); } string ConvolutionDimensionNumbersToString( @@ -2904,31 +2869,6 @@ string ConvolutionDimensionNumbersToString( StrJoin(output_dims, "")); } -string HloInstruction::DotDimensionNumbersToString() const { - std::vector result; - if (dot_dimension_numbers_ == nullptr) { - return ""; - } - const DotDimensionNumbers& dnums = *dot_dimension_numbers_; - if (!dnums.lhs_batch_dimensions().empty()) { - result.push_back(StrCat("lhs_batch_dims={", - StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); - } - result.push_back(StrCat("lhs_contracting_dims={", - StrJoin(dnums.lhs_contracting_dimensions(), ","), - "}")); - - if (!dnums.rhs_batch_dimensions().empty()) { - result.push_back(StrCat("rhs_batch_dims={", - StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); - } - result.push_back(StrCat("rhs_contracting_dims={", - StrJoin(dnums.rhs_contracting_dimensions(), ","), - "}")); - - return StrJoin(result, ", "); -} - StatusOr StringToRandomDistribution(const string& name) { static std::unordered_map* map = [] { static auto* map = new std::unordered_map; @@ -2947,31 +2887,13 @@ StatusOr StringToRandomDistribution(const string& name) { return found->second; } -string HloInstruction::PrecisionConfigToString() const { - if 
(precision_config_.operand_precision().empty()) { - return ""; - } - return StrCat( - "operand_precision={", - StrJoin(precision_config_.operand_precision(), ",", - [](string* out, int32 precision) { - CHECK(PrecisionConfigProto::Precision_IsValid(precision)) - << precision; - StrAppend(out, PrecisionToString( - static_cast( - precision))); - }), - "}"); -} - -StatusOr StringToPrecision( - const string& name) { - static std::unordered_map* map = [] { +StatusOr StringToPrecision(const string& name) { + static std::unordered_map* map = [] { static auto* map = - new std::unordered_map; - for (int i = 0; i < PrecisionConfigProto::Precision_ARRAYSIZE; i++) { - if (PrecisionConfigProto::Precision_IsValid(i)) { - auto value = static_cast(i); + new std::unordered_map; + for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) { + if (PrecisionConfig::Precision_IsValid(i)) { + auto value = static_cast(i); (*map)[PrecisionToString(value)] = value; } } @@ -3024,6 +2946,16 @@ Status HloInstruction::set_backend_config( return ret; } +const PrecisionConfig& HloInstruction::precision_config() const { + if (auto* convolution = DynCast(this)) { + return convolution->precision_config(); + } + if (auto* dot = DynCast(this)) { + return dot->precision_config(); + } + LOG(FATAL) << "Unimplemented method."; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); @@ -3328,4 +3260,15 @@ const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() return Cast(this)->scatter_dimension_numbers(); } +const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const { + return Cast(this)->dot_dimension_numbers(); +} + +const DomainMetadata& HloInstruction::operand_side_metadata() const { + return Cast(this)->operand_side_metadata(); +} + +const DomainMetadata& HloInstruction::user_side_metadata() const { + return Cast(this)->user_side_metadata(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h 
b/tensorflow/compiler/xla/service/hlo_instruction.h index cca134e8b45f89a1c395c791029ee68eeec3c8f0..4f6cac1396c16beb5cebf909032dead711d77a61 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -82,6 +82,7 @@ class HloPrintOptions { print_operand_shape_(true), print_program_shape_(true), print_percent_(true), + print_control_dependencies_(true), canonicalize_instruction_names_(false), indent_amount_(0), is_in_nested_computation_(false) {} @@ -94,7 +95,8 @@ class HloPrintOptions { .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) - .set_print_percent(false); + .set_print_percent(false) + .set_print_control_dependencies(false); } // Options to produce the canonical string representing an isomorphic @@ -108,6 +110,7 @@ class HloPrintOptions { .set_print_operand_shape(true) .set_print_program_shape(false) .set_print_percent(false) + .set_print_control_dependencies(false) .set_canonicalize_instruction_names(true); } @@ -153,6 +156,12 @@ class HloPrintOptions { return *this; } + // If true, control dependencies will be printed. + HloPrintOptions& set_print_control_dependencies(bool value) { + print_control_dependencies_ = value; + return *this; + } + // If true, only a part of operands will be printed out, and their names will // be omitted (note that in this case the text will not be parsable). 
HloPrintOptions& set_compact_operands(bool value) { @@ -190,6 +199,9 @@ class HloPrintOptions { bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } + bool print_control_dependencies() const { + return print_control_dependencies_; + } bool canonicalize_instruction_names() const { return canonicalize_instruction_names_; } @@ -205,6 +217,7 @@ class HloPrintOptions { bool print_operand_shape_; bool print_program_shape_; bool print_percent_; + bool print_control_dependencies_; bool canonicalize_instruction_names_; int indent_amount_; bool is_in_nested_computation_; @@ -346,8 +359,7 @@ class HloInstruction { const string& name); // Creates a literal constant instruction. - static std::unique_ptr CreateConstant( - std::unique_ptr literal); + static std::unique_ptr CreateConstant(Literal literal); // Creates an Iota instruction. static std::unique_ptr CreateIota(const Shape& shape, @@ -405,9 +417,9 @@ class HloInstruction { // and window describes how the filter is applied to lhs. static std::unique_ptr CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, + int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + const PrecisionConfig& precision_config); // Creates an FFT op, of the type indicated by fft_type. static std::unique_ptr CreateFft( @@ -418,13 +430,8 @@ class HloInstruction { // dimensions specified in 'dimension_numbers'. static std::unique_ptr CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dimension_numbers); - - // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 - // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS - // and the RHS must be of rank 2. 
- static std::unique_ptr CreateCanonicalDot( - const Shape& shape, HloInstruction* lhs, HloInstruction* rhs); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to @@ -865,11 +872,6 @@ class HloInstruction { return false; } - if (!absl::c_equal(precision_config_.operand_precision(), - other.precision_config_.operand_precision())) { - return false; - } - return IdenticalSlowPath(other, eq_computations); } @@ -1084,15 +1086,6 @@ class HloInstruction { return other->has_sharding() ? sharding() == other->sharding() : false; } - // Retrieves the operand side metadata of a kDomain instruction. - const DomainMetadata& operand_side_metadata() const { - return *operand_side_metadata_; - } - // Retrieves the user side metadata of a kDomain instruction. - const DomainMetadata& user_side_metadata() const { - return *user_side_metadata_; - } - // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain // properties of the new instruction are copied into the derived one. As of @@ -1100,18 +1093,6 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // Returns data on the dimension numbers used for a dot operation. - const DotDimensionNumbers& dot_dimension_numbers() const { - CHECK(dot_dimension_numbers_ != nullptr); - return *dot_dimension_numbers_; - } - - // Returns the dump string of the dot dimension numbers. - string DotDimensionNumbersToString() const; - - // Returns the dump string of the precision configuration. - string PrecisionConfigToString() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. 
"this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1261,12 +1242,8 @@ class HloInstruction { // information. Transformations to other HLOs will not preserve this // information but it is presumed that the alternate lowering is strictly // superior. - const PrecisionConfigProto& precision_config() const { - return precision_config_; - } - void set_precision_config(const PrecisionConfigProto& precision_config) { - precision_config_ = precision_config; - } + // Precondition: opcode must be kConvolution or kDot. + const PrecisionConfig& precision_config() const; // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } @@ -1509,6 +1486,15 @@ class HloInstruction { // Delegates to HloScatterInstruction::scatter_dimension_numbers(). const ScatterDimensionNumbers& scatter_dimension_numbers() const; + // Delegates to HloDotInstruction::dot_dimension_numbers(). + const DotDimensionNumbers& dot_dimension_numbers() const; + + // Delegates to HloDomainInstruction::operand_side_metadata(). + const DomainMetadata& operand_side_metadata() const; + + // Delegates to HloDomainInstruction::user_side_metadata(). + const DomainMetadata& user_side_metadata() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1648,22 +1634,12 @@ class HloInstruction { // Result shape of this instruction. Shape shape_; - // Describes the dimension numbers used for a dot. - std::unique_ptr dot_dimension_numbers_; - - // Used to tag kCopy instructions that are eligible for copy elision. - bool copy_elision_allowed_ = true; - // The sharding, if one exists. // Uses std::shared_ptr to allow reuse of the same sharding object between // HloInstructions and other components as HloSharding can be very large for // many element tuples. std::shared_ptr sharding_; - // Fields used by the kDomain instruction. 
- std::unique_ptr operand_side_metadata_; - std::unique_ptr user_side_metadata_; - // Computations called by this instruction. std::vector called_computations_; @@ -1677,10 +1653,6 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; - // Information used to communicate to the implementation about the algorithm - // used to produce results. See the documentation on precision_config(). - PrecisionConfigProto precision_config_; - // String identifier for instruction. string name_; @@ -1703,12 +1675,12 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); -string PrecisionToString(const PrecisionConfigProto::Precision& precision); +string PrecisionToString(const PrecisionConfig::Precision& precision); string ConvolutionDimensionNumbersToString( const ConvolutionDimensionNumbers& dnums); StatusOr StringToRandomDistribution(const string& name); -StatusOr StringToPrecision(const string& name); +StatusOr StringToPrecision(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 76b0e940a656ee2f54781b927fdca367a83056c6..c1b7c3832b44b5d65b715dffa5211a5c92e17953 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1147,8 +1147,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, 
dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1188,8 +1188,8 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(s, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + s, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1239,8 +1239,8 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + data_shape, a, b_t, dot_dnums, DefaultPrecisionConfig(2))); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( @@ -1320,8 +1320,8 @@ TEST_F(HloInstructionTest, Stringification) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto options = HloPrintOptions().set_print_metadata(false); @@ -1485,8 +1485,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - 
HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto options = HloPrintOptions().Canonical(); @@ -1527,8 +1527,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1583,8 +1583,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1752,9 +1752,9 @@ TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) { auto* conv = module->entry_computation()->root_instruction(); auto clone = conv->Clone(); - EXPECT_THAT(clone->precision_config().operand_precision(), - ::testing::ElementsAre(PrecisionConfigProto::HIGH, - PrecisionConfigProto::DEFAULT)); + EXPECT_THAT( + clone->precision_config().operand_precision(), + ::testing::ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT)); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 
e46afa764f519c9f7b6e3e9a8a37c84bd173b9a2..e92882c22a6ef1dd43440d3c94c7d233c9a4fb5d 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -47,6 +47,27 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, return instruction->IsElementwiseOnOperand(operand_index); }); } + +string PrecisionConfigToString(const PrecisionConfig& precision_config) { + if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) { + return static_cast(precision) == + PrecisionConfig::DEFAULT; + })) { + return ""; + } + + return StrCat( + "operand_precision={", + StrJoin( + precision_config.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision; + StrAppend(out, + PrecisionToString( + static_cast(precision))); + }), + "}"); +} } // namespace HloBatchNormInstruction::HloBatchNormInstruction( @@ -824,8 +845,8 @@ std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_); } -HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) - : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), +HloConstantInstruction::HloConstantInstruction(Literal literal) + : HloInstruction(HloOpcode::kConstant, literal.shape()), literal_(std::move(literal)) {} HloConstantInstruction::HloConstantInstruction(const Shape& shape) @@ -833,7 +854,7 @@ HloConstantInstruction::HloConstantInstruction(const Shape& shape) HloInstructionProto HloConstantInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - if (literal_ != nullptr) { + if (literal_.has_value()) { *proto.mutable_literal() = literal_->ToProto(); } return proto; @@ -855,7 +876,7 @@ void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, if (!mutable_array_subshape->has_layout() || 
!LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { - literal_ = literal_->Relayout(new_layout, shape_index); + *literal_ = literal_->Relayout(new_layout, shape_index); *mutable_array_subshape->mutable_layout() = new_layout; } } @@ -872,7 +893,8 @@ std::unique_ptr HloConstantInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { - return absl::make_unique(literal_->CloneToUnique()); + CHECK(literal_.has_value()); + return absl::make_unique(literal_->Clone()); } string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( @@ -880,7 +902,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( CanonicalNameMap* canonical_name_map) const { string operands; // For constants, show the actual value in place of an empty operand list. - if (literal_ != nullptr && + if (literal_.has_value() && ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple @@ -915,7 +937,7 @@ HloTraceInstruction::HloTraceInstruction(const string& tag, HloInstructionProto HloTraceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_literal() = literal_->ToProto(); + *proto.mutable_literal() = literal_.ToProto(); return proto; } @@ -1628,12 +1650,14 @@ std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) + int64 feature_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) : HloInstruction(HloOpcode::kConvolution, shape), + feature_group_count_(feature_group_count), window_(window), 
convolution_dimension_numbers_(dimension_numbers), - feature_group_count_(feature_group_count) { + precision_config_(precision_config) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1661,6 +1685,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = convolution_dimension_numbers_; proto.set_feature_group_count(feature_group_count_); + *proto.mutable_precision_config() = precision_config_; return proto; } @@ -1672,7 +1697,15 @@ std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( } extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( convolution_dimension_numbers_))); - extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + if (feature_group_count_ != 1) { + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + } + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; } @@ -1688,7 +1721,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(window(), casted_other.window()) && protobuf_util::ProtobufEquals( convolution_dimension_numbers(), - casted_other.convolution_dimension_numbers()); + casted_other.convolution_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); } std::unique_ptr @@ -1697,8 +1732,8 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( - shape, new_operands[0], new_operands[1], window(), - convolution_dimension_numbers_, feature_group_count_); + shape, new_operands[0], new_operands[1], feature_group_count_, window(), + convolution_dimension_numbers_, precision_config_); } HloReduceWindowInstruction::HloReduceWindowInstruction( @@ -2157,4 
+2192,113 @@ std::unique_ptr HloIotaInstruction::CloneWithNewOperandsImpl( return absl::make_unique(shape, iota_dimension()); } +HloDotInstruction::HloDotInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config) + : HloInstruction(HloOpcode::kDot, shape), + dot_dimension_numbers_(dimension_numbers), + precision_config_(precision_config) { + AppendOperand(lhs); + AppendOperand(rhs); +} + +HloInstructionProto HloDotInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_; + *proto.mutable_precision_config() = precision_config_; + return proto; +} + +std::vector HloDotInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra = {DotDimensionNumbersToString()}; + + string precision_config_string = PrecisionConfigToString(precision_config_); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + return extra; +} + +bool HloDotInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals(dot_dimension_numbers(), + casted_other.dot_dimension_numbers()) && + protobuf_util::ProtobufEquals(precision_config(), + casted_other.precision_config()); +} + +std::unique_ptr HloDotInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return absl::make_unique( + shape, new_operands[0], new_operands[1], dot_dimension_numbers_, + precision_config_); +} + +string HloDotInstruction::DotDimensionNumbersToString() const { + std::vector result; + const DotDimensionNumbers& dnums = dot_dimension_numbers_; + if (!dnums.lhs_batch_dimensions().empty()) { + 
result.push_back(StrCat("lhs_batch_dims={", + StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("lhs_contracting_dims={", + StrJoin(dnums.lhs_contracting_dimensions(), ","), + "}")); + + if (!dnums.rhs_batch_dimensions().empty()) { + result.push_back(StrCat("rhs_batch_dims={", + StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("rhs_contracting_dims={", + StrJoin(dnums.rhs_contracting_dimensions(), ","), + "}")); + + return StrJoin(result, ", "); +} + +HloDomainInstruction::HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata) + : HloInstruction(HloOpcode::kDomain, shape), + operand_side_metadata_(std::move(operand_side_metadata)), + user_side_metadata_(std::move(user_side_metadata)) { + AppendOperand(operand); +} + +std::vector HloDomainInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", user_side_metadata_->ToString(), + ", exit=", operand_side_metadata_->ToString(), "}")}; + } + return {}; +} + +bool HloDomainInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return operand_side_metadata().Matches( + casted_other.operand_side_metadata()) && + user_side_metadata().Matches(casted_other.user_side_metadata()); +} + +std::unique_ptr HloDomainInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique( + shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h 
b/tensorflow/compiler/xla/service/hlo_instructions.h index 323038357993c4e9b99d1527aa8f593ada92f1c8..2d7bc83855e761ed313d831a1252a54130910bbe 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -580,13 +580,13 @@ class HloSliceInstruction : public HloInstruction { class HloConstantInstruction : public HloInstruction { public: - explicit HloConstantInstruction(std::unique_ptr literal); + explicit HloConstantInstruction(Literal literal); // Used when the literal is too large and dropped. explicit HloConstantInstruction(const Shape& shape); // Returns the literal associated with this instruction. const Literal& literal() const { return *literal_; } // Returns whether there is literal associated with this instruction. - bool HasLiteral() const { return literal_ != nullptr; } + bool HasLiteral() const { return literal_.has_value(); } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -610,15 +610,14 @@ class HloConstantInstruction : public HloInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // TODO(b/36360764): Remove unique_ptr wrapping. - std::unique_ptr literal_; + absl::optional literal_; }; class HloTraceInstruction : public HloInstruction { public: explicit HloTraceInstruction(const string& tag, HloInstruction* operand); // Returns a tag to be used in tracing. - string TracingTag() const { return literal_->GetR1U8AsString(); } + string TracingTag() const { return literal_.GetR1U8AsString(); } // Returns a serialized representation of this instruction. 
HloInstructionProto ToProto() const override; @@ -631,8 +630,7 @@ class HloTraceInstruction : public HloInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - // TODO(b/36360764): Remove unique_ptr wrapping. - std::unique_ptr literal_; + Literal literal_; }; class HloFusionInstruction : public HloInstruction { @@ -942,9 +940,9 @@ class HloConvolutionInstruction : public HloInstruction { public: explicit HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, + int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count); + const PrecisionConfig& precision_config); const Window& window() const override { return window_; } void set_window(const Window& window) override { window_ = window; } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -957,6 +955,16 @@ class HloConvolutionInstruction : public HloInstruction { // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count() const { return feature_group_count_; } + + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfig& precision_config() const { return precision_config_; } + string ToCategory() const override; // Returns a serialized representation of this instruction. 
HloInstructionProto ToProto() const override; @@ -972,12 +980,16 @@ class HloConvolutionInstruction : public HloInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - Window window_; - // Describes the dimension numbers used for a convolution. - ConvolutionDimensionNumbers convolution_dimension_numbers_; // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count_; + // Describes the window used for a convolution. + Window window_; + // Describes the dimension numbers used for a convolution. + ConvolutionDimensionNumbers convolution_dimension_numbers_; + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfig precision_config_; }; class HloReduceWindowInstruction : public HloInstruction { @@ -1270,6 +1282,85 @@ class HloIotaInstruction : public HloInstruction { const int64 iota_dimension_; }; +class HloDotInstruction : public HloInstruction { + public: + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch + // dimensions specified in 'dimension_numbers'. + explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig& precision_config); + + // Returns data on the dimension numbers used for a dot operation. + const DotDimensionNumbers& dot_dimension_numbers() const { + return dot_dimension_numbers_; + } + + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. 
Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. + const PrecisionConfig& precision_config() const { return precision_config_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + // Returns the dump string of the dot dimension numbers. + string DotDimensionNumbersToString() const; + + // Describes the dimension numbers used for a dot. + DotDimensionNumbers dot_dimension_numbers_; + + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfig precision_config_; +}; + +class HloDomainInstruction : public HloInstruction { + public: + explicit HloDomainInstruction( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata); + + // Retrieves the operand side metadata of a kDomain instruction. + const DomainMetadata& operand_side_metadata() const { + return *operand_side_metadata_; + } + // Retrieves the user side metadata of a kDomain instruction. + const DomainMetadata& user_side_metadata() const { + return *user_side_metadata_; + } + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. 
+ std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + std::unique_ptr operand_side_metadata_; + std::unique_ptr user_side_metadata_; +}; } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc similarity index 72% rename from tensorflow/compiler/xla/service/hlo_scheduling.cc rename to tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index 0fc3b268c059802a3882ad5032a9fe5da28cbf23..c7ec88d450712b0831971139f165934ef5524845 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include #include @@ -70,7 +70,7 @@ class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions // containing the given HLO computation. - static StatusOr> Run( + static StatusOr Run( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -229,8 +229,8 @@ class ListScheduler { return {BytesFreedIfScheduled(entry), entry.instruction->user_count()}; } - std::vector CreateSchedule() { - std::vector schedule; + HloInstructionSequence CreateSchedule() { + HloInstructionSequence schedule; // Populate the ready list with instructions which have no operands or // control predecessors. 
@@ -374,7 +374,7 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> ScheduleComputationHelper( +StatusOr ScheduleComputationHelper( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -392,7 +392,7 @@ StatusOr> ScheduleComputationHelper( } // namespace -StatusOr> DFSMemoryScheduler( +StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -443,7 +443,7 @@ StatusOr> DFSMemoryScheduler( // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, and next by cumulative size, with a // tiebreaker by name for determinism. - std::vector sequence; + HloInstructionSequence sequence; FunctionVisitor visitor([&sequence](HloInstruction* hlo) { sequence.push_back(hlo); return Status::OK(); @@ -463,7 +463,7 @@ StatusOr> DFSMemoryScheduler( return sequence; } // namespace xla -StatusOr> ListMemoryScheduler( +StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -473,18 +473,16 @@ StatusOr> ListMemoryScheduler( memory_by_computation); } -StatusOr> PostOrderMemoryScheduler( +StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation) { - const auto& post_order = computation.MakeInstructionPostOrder(); - return std::vector{post_order.begin(), - post_order.end()}; + return HloInstructionSequence(computation.MakeInstructionPostOrder()); } -StatusOr> DefaultMemoryScheduler( +StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -499,7 +497,7 @@ 
StatusOr> DefaultMemoryScheduler( // List wins for most of our benchmarks; postorder-based schedulers win for // some RNNs. TF_ASSIGN_OR_RETURN( - std::vector list_sequence, + HloInstructionSequence list_sequence, ListMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 list_memory, @@ -508,7 +506,7 @@ StatusOr> DefaultMemoryScheduler( size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, + TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 dfs_memory, @@ -518,7 +516,7 @@ StatusOr> DefaultMemoryScheduler( VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); TF_ASSIGN_OR_RETURN( - std::vector post_order_sequence, + HloInstructionSequence post_order_sequence, PostOrderMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 post_order_memory, @@ -545,32 +543,35 @@ StatusOr> DefaultMemoryScheduler( } } -StatusOr ScheduleComputationsInModule( +StatusOr ScheduleModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm) { - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(&module); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); tensorflow::gtl::FlatMap memory_by_computation; for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { - TF_ASSIGN_OR_RETURN(auto one_computation_sequence, + TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, ScheduleComputationHelper( *computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); 
memory_by_computation[computation] = HeapSimulator::MinimumMemoryForComputation( - *computation, one_computation_sequence, *points_to_analysis, + *computation, computation_sequence, *points_to_analysis, size_function, &memory_by_computation) .ValueOrDie(); - sequence[computation] = std::move(one_computation_sequence); + schedule.set_sequence(computation, std::move(computation_sequence)); } } - VLOG(1) << "Module schedule:\n" << sequence; - return sequence; + VLOG(1) << "Module schedule:\n" << schedule; + + TF_RETURN_IF_ERROR(schedule.Verify()); + + return std::move(schedule); } -StatusOr> ScheduleOneComputation( +StatusOr ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); @@ -581,187 +582,22 @@ StatusOr> ScheduleOneComputation( size_function, nullptr, empty_map); } -tensorflow::gtl::FlatMap> -ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) { - tensorflow::gtl::FlatMap> id_sequence; - for (const auto& computation_sequence : sequence) { - for (const HloInstruction* instruction : computation_sequence.second) { - id_sequence[computation_sequence.first].push_back( - instruction->unique_id()); - } - } - return id_sequence; -} - -Status UpdateSchedule( - const HloModule& module, - const tensorflow::gtl::FlatMap>& - id_sequence, - SequentialHloOrdering::HloModuleSequence* sequence) { - // Map from unique ID to HloInstruction pointer for instructions in the - // module. - tensorflow::gtl::FlatMap id_to_instruction; - // Set of all HloInstructions in the schedule. 
- tensorflow::gtl::FlatSet ids_in_schedule; - std::vector nonfusion_computations = - module.MakeNonfusionComputations(); - for (const HloComputation* computation : nonfusion_computations) { - for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK( - id_to_instruction.insert({instruction->unique_id(), instruction}) - .second); - } - for (int id : id_sequence.at(computation)) { - ids_in_schedule.insert(id); - } - } - - // Map from HloInstruction X to newly added instructions (instruction is in - // module, but not in schedule) which use X. If an instruction is not in the - // map, then it has no users which are newly added instructions. - tensorflow::gtl::FlatMap> - new_instruction_uses; - - // For each newly added instruction, this is the count of the instruction's - // operands that have not yet been scheduled. When this value reaches zero, - // then the instruction may be placed in the schedule. - tensorflow::gtl::FlatMap - unscheduled_operand_count; - // For each computation, this is the set of newly added instructions which - // have no operands. These must be handled specially and are added to the - // beginning of the schedule. - tensorflow::gtl::FlatMap> - new_zero_operand_instructions; - for (const HloComputation* computation : nonfusion_computations) { - new_zero_operand_instructions[computation] = {}; - for (const HloInstruction* instruction : computation->instructions()) { - if (ids_in_schedule.count(instruction->unique_id()) == 0) { - // This is a newly added instruction which is not in the schedule. 
- for (const HloInstruction* operand : instruction->operands()) { - new_instruction_uses[operand].push_back(instruction); - } - if (instruction->operands().empty()) { - new_zero_operand_instructions[computation].push_back(instruction); - } - unscheduled_operand_count[instruction] = instruction->operand_count(); - } - } - } - - // Update the schedule with the newly added instructions, and remove any - // instructions no longer in the graph. - for (const HloComputation* computation : nonfusion_computations) { - std::vector old_computation_sequence = - std::move(sequence->at(computation)); - sequence->at(computation).clear(); - - // Create a worklist of newly added instructions which are ready to be added - // to the schedule. Initialize worklist with those that have zero operands. - std::queue worklist; - for (const HloInstruction* instruction : - new_zero_operand_instructions.at(computation)) { - worklist.push(instruction); - } - - // Lambda which schedules all instructions on the worklist. - auto schedule_worklist = [&]() { - while (!worklist.empty()) { - const HloInstruction* instruction = worklist.front(); - worklist.pop(); - sequence->at(computation).push_back(instruction); - std::vector* new_users = - tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); - if (new_users != nullptr) { - // This just-scheduled instruction has users which are newly added to - // the module. Update the number of unscheduled operands and push the - // newly added instruction to the worklist if it is ready to - // schedule. 
- for (const HloInstruction* new_user : *new_users) { - unscheduled_operand_count.at(new_user)--; - CHECK_GE(unscheduled_operand_count.at(new_user), 0); - if (unscheduled_operand_count.at(new_user) == 0) { - worklist.push(new_user); - } - } - } - } - }; - - schedule_worklist(); - for (int id : id_sequence.at(computation)) { - auto it = id_to_instruction.find(id); - if (it == id_to_instruction.end()) { - // This instruction in the schedule is no longer in the module. - continue; - } - const HloInstruction* instruction = it->second; - worklist.push(instruction); - schedule_worklist(); - } - } - - TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence)); - return Status::OK(); +HloMemoryScheduler::HloMemoryScheduler( + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) + : size_function_(size_function), algorithm_(algorithm) {} + +StatusOr HloMemoryScheduler::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(*module, size_function_, algorithm_)); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + return true; } -Status VerifySchedule( - const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& sequence) { - VLOG(2) << "VerifySchedule()"; - XLA_VLOG_LINES(2, module.ToString()); - VLOG(2) << sequence; - - // Verify the set of computations in the sequence is exactly the set of - // computations in the module. - std::vector nonfusion_computations = - module.MakeNonfusionComputations(); - TF_RET_CHECK(nonfusion_computations.size() == sequence.size()); - tensorflow::gtl::FlatSet computations_in_module( - module.computations().begin(), module.computations().end()); - for (const auto& computation_sequence : sequence) { - TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1); - } - - // For each computation verify the set of instructions is the same and that - // each dependency and control edge is honored. 
- for (const HloComputation* computation : nonfusion_computations) { - tensorflow::gtl::FlatMap instruction_position; - int pos = 0; - for (const HloInstruction* instruction : sequence.at(computation)) { - TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) - << "Instruction " << instruction->name() - << " appears more than once in the schedule"; - pos++; - } - - TF_RET_CHECK(instruction_position.size() == - computation->instruction_count()); - for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(instruction_position.count(instruction) == 1) - << "Instruction " << instruction->name() << " is not in schedule"; - } - - for (const HloInstruction* instruction : computation->instructions()) { - for (const HloInstruction* operand : instruction->operands()) { - TF_RET_CHECK(instruction_position.at(operand) < - instruction_position.at(instruction)) - << "Instruction " << instruction->name() - << " is not scheduled after its operand " << operand->name(); - } - - for (const HloInstruction* pred : instruction->control_predecessors()) { - TF_RET_CHECK(instruction_position.at(pred) < - instruction_position.at(instruction)) - << "Instruction " << instruction->name() - << " is not scheduled after its control predecessor " - << pred->name(); - } - } - } - - return Status::OK(); +StatusOr HloDescheduler::Run(HloModule* module) { + bool changed = module->has_schedule(); + module->clear_schedule(); + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h similarity index 53% rename from tensorflow/compiler/xla/service/hlo_scheduling.h rename to tensorflow/compiler/xla/service/hlo_memory_scheduler.h index d06b8d9a5cdef82380bd68ae0991a3957db80f48..5e02868ebadaf06458f81e4f10ac04f882421ec8 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -13,14 +13,16 @@ See the 
License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ #include #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -32,14 +34,14 @@ namespace xla { // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. 
-typedef std::function>( +typedef std::function( const HloComputation&, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, const tensorflow::gtl::FlatMap&)> MemorySchedulerAlgorithm; // List scheduler -StatusOr> ListMemoryScheduler( +StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -47,7 +49,7 @@ StatusOr> ListMemoryScheduler( memory_by_computation); // DFS-order scheduler -StatusOr> DFSMemoryScheduler( +StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -55,7 +57,7 @@ StatusOr> DFSMemoryScheduler( memory_by_computation); // Naive Post Order scheduler -StatusOr> PostOrderMemoryScheduler( +StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -65,63 +67,57 @@ StatusOr> PostOrderMemoryScheduler( // The default scheduling algorithm. Runs both the list scheduler // and the DFS scheduler, and chooses whichever returns a lower min-memory, // not accounting for fragmentation. -StatusOr> DefaultMemoryScheduler( +StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation); -// Returns an HloModuleSequence which seeks to minimize the memory required for +// Returns an HloSchedule which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. 
-StatusOr ScheduleComputationsInModule( +StatusOr ScheduleModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm = {}); // Computes the schedule for a single computation. // Currently only used by the GPU backend. -StatusOr> ScheduleOneComputation( +StatusOr ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); -// Transforms the given schedule such that it is (again) a valid schedule for -// the module. This is used to update a schedule after the HLO module has been -// transformed in some way. In general, the only transformations to the module -// for which a schedule can be updated is the addition or removal of -// instructions to/from the module. Updating the schedule after new dependencies -// between existing instructions in the module is not supported and may result -// in an error status returned. -// -// Instructions in the module which also exist in the given schedule will remain -// in the same order in the updated schedule. Instructions which exist in the -// module but not in the given schedule will be placed as early as possible in -// the updated schedule. -// -// 'id_sequence' is a mirror of the given schedule 'sequence' but with -// HloInstruction ids rather than HloInstruction pointers. This should be -// constructed using ComputeIdSchedule below after the schedule is constructed -// but before the HLO module is transformed. -Status UpdateSchedule( - const HloModule& module, - const tensorflow::gtl::FlatMap>& - id_sequence, - SequentialHloOrdering::HloModuleSequence* sequence); - -// Constructs a copy of the given schedule but with HloInstruction unique ids -// rather than HloInstruction pointers. This is necessary for updating a -// schedule as HloInstruction points in the schedule may become invalid if -// instructions are removed from the module. Used by UpdateSchedule above.. 
-// TODO(b/113175018): Remove this function when HLO schedule is its own class. -tensorflow::gtl::FlatMap> -ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence); - -// Verifies that the given schedule is valid for the given module. Specifically, -// the schedule contains exactly the instructions in the module and every -// dependency in the module is satisfied in the schedule. -Status VerifySchedule(const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& sequence); +// A pass which schedules the HLO instructions in a module. The HloModule's +// schedule field is set to the resulting HloSchedule using +// HloModule::set_schedule. +class HloMemoryScheduler : public HloPassInterface { + public: + // size_function is the function returning the number of bytes required for a + // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not + // specified, then DefaultMemoryScheduler is used. + HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); + ~HloMemoryScheduler() override = default; + absl::string_view name() const override { return "hlo-memory-scheduler"; } + + StatusOr Run(HloModule* module) override; + + private: + LogicalBuffer::SizeFunction size_function_; + MemorySchedulerAlgorithm algorithm_; +}; + +// A trivial pass which clears the schedule currently set on the +// HloModule. After this pass runs HloModule::has_schedule will return false. 
+class HloDescheduler : public HloPassInterface { + public: + HloDescheduler() = default; + ~HloDescheduler() override = default; + absl::string_view name() const override { return "hlo-descheduler"; } + + StatusOr Run(HloModule* module) override; +}; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc similarity index 56% rename from tensorflow/compiler/xla/service/hlo_scheduling_test.cc rename to tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index d49d09d459758840ce0f9f0b05e3c033da3337f8..1b9e9bfc77c3ba91e5b878f4aa42d26d8267a49a 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -66,21 +67,34 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); + HloMemoryScheduler scheduler([](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get())); + EXPECT_TRUE(changed); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + const std::vector& sequence = + module->schedule().sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); // The first instruction should be the parameter and the last the root "sub". - EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); - EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); + EXPECT_EQ(param, sequence.front()); + EXPECT_EQ(sub, sequence.back()); - SequentialHloOrdering ordering(module.get(), sequence); + SequentialHloOrdering ordering(module->schedule()); EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); + + // Clear the schedule using the descheduling pass. 
+ HloDescheduler descheduler; + EXPECT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed, + descheduler.Run(module.get())); + EXPECT_TRUE(descheduler_changed); + EXPECT_FALSE(module->has_schedule()); } TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) { @@ -108,28 +122,26 @@ ENTRY root { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + const std::vector& sequence = + schedule.sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); std::unordered_map instructions_by_name; - for (const HloInstruction* instruction : - sequence.at(module->entry_computation())) { + for (const HloInstruction* instruction : sequence) { instructions_by_name[instruction->name()] = instruction; } // The first instruction should be the parameter and the last the root. - EXPECT_EQ(instructions_by_name.at("param"), - sequence.at(module->entry_computation()).front()); - EXPECT_EQ(instructions_by_name.at("result"), - sequence.at(module->entry_computation()).back()); + EXPECT_EQ(instructions_by_name.at("param"), sequence.front()); + EXPECT_EQ(instructions_by_name.at("result"), sequence.back()); // Instructions "d" and "e" will both be schedulable at the same time, but // instruction "d" allows us to free the buffer of "p1", so the list scheduler // should prefer it. 
- SequentialHloOrdering ordering(module.get(), sequence); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), instructions_by_name.at("e"))); } @@ -220,13 +232,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. auto entry_computation = module->entry_computation(); EXPECT_EQ(entry_computation->instruction_count(), - sequence.at(entry_computation).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(entry_computation).size()); + SequentialHloOrdering ordering(schedule); // This schedule is an example of List's greedy heuristics being suboptimal. // The while_loop is more expensive than transpose, so it would have been // better to schedule it first, instead of during the busy time. @@ -243,13 +255,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { // HeapSimulator doesn't account for subcomputations EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); // HeapSimulator accounts for subcomputations. The output buffer is aliased, // so we don't double count. 
EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } @@ -281,19 +293,18 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); // tuple allocates the tuple buffer and doesn't free anything. // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. // abs_abs2 should be scheduled before tuple by List. 
@@ -332,18 +343,18 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { auto fusion = computation->CreateFusionInstruction( {tuple, mul, add}, HloInstruction::FusionKind::kLoop); - TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule( - *module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), 2); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), 2); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); // fusion allocates memory for the tuple elements and doesn't free anything, // so it's more expensive than exp. EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); @@ -391,12 +402,12 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. 
auto entry_computation = module->entry_computation(); - EXPECT_EQ(entry_computation->instruction_count(), - sequence.at(entry_computation).size()); + EXPECT_EQ(module->entry_computation()->instruction_count(), + schedule.sequence(module->entry_computation()).size()); tensorflow::gtl::FlatMap memory_by_computation; memory_by_computation[cond_computation] = 17; @@ -406,262 +417,16 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { // HeapSimulator doesn't account for subcomputations EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); // HeapSimulator accounts for subcomputations. Cond is the largest one. // The output buffer of the while is aliased. EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } -TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) { - // Updating the schedule of an unchanged HLO module should not affect the - // schedule at all. 
- const string module_str = R"( -HloModule UpdateScheduleUnchanged - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - std::vector entry_schedule = sequence.begin()->second; - - EXPECT_EQ(entry_schedule.size(), 6); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(entry_schedule, sequence.begin()->second); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) { - // Add some additional instructions to a module and verify the schedule can be - // updated. 
- const string module_str = R"( -HloModule UpdateScheduleWithNewInstructions - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - HloComputation* entry = module->entry_computation(); - const Shape shape = entry->root_instruction()->shape(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kSubtract, constant, entry->root_instruction())); - entry->set_root_instruction(sub); - - auto in_schedule = [&](const HloInstruction* hlo) { - return std::find(sequence.at(entry).begin(), sequence.at(entry).end(), - hlo) != sequence.at(entry).end(); - }; - - EXPECT_EQ(sequence.at(entry).size(), 6); - EXPECT_FALSE(in_schedule(constant)); - EXPECT_FALSE(in_schedule(sub)); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 8); - EXPECT_TRUE(in_schedule(constant)); - EXPECT_TRUE(in_schedule(sub)); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) { - // Add and delete some instructions from a module and verify that the schedule - // can be updated successfully. 
- const string module_str = R"( -HloModule UpdateScheduleWithAddedAndDeletedInstruction - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - // Set the entry root to some expression containing just a parameter and a - // constant. - HloComputation* entry = module->entry_computation(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - HloInstruction* new_root = entry->AddInstruction( - HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, - constant, entry->parameter_instruction(0))); - entry->set_root_instruction(new_root); - - // DCE should remove everything but the parameters and the newly added code. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(entry).size(), 6); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 4); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) { - // Completely replace a module with an entirely new set of instructions and - // verify that the schedule can be updated successfully. 
- const string module_str = R"( -HloModule UpdateScheduleWithCompletelyReplacedModule - -ENTRY main { - a = f32[] constant(42.0) - b = f32[] constant(123.0) - ROOT sum = f32[] add(a, b) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - // Replace the entry computation with the negation of a constant. - HloComputation* entry = module->entry_computation(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kNegate, constant)); - entry->set_root_instruction(new_root); - - // DCE the old instructions. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(entry).size(), 3); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 2); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) { - // Create changes to more than one computation in an HLO module and verify - // that the schedule can be updated. 
- const string module_str = R"( -HloModule UpdateScheduleWithMultipleComputations - -%Body (param.1: (s32[], token[])) -> (s32[], token[]) { - %param.1 = (s32[], token[]) parameter(0) - %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 - %constant.1 = s32[] constant(1) - %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) - %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 - %after-all = token[] after-all(token[] %get-tuple-element.2) - ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) -} - -%Cond (param: (s32[], token[])) -> pred[] { - %param = (s32[], token[]) parameter(0) - %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 - %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) -} - -ENTRY %WhileLoop () -> s32[] { - %zero = s32[] constant(0) - %init_token = token[] after-all() - %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) - %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body - ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), - /*pointer_size=*/sizeof(void*)); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - const HloInstruction* xla_while = - module->entry_computation()->root_instruction()->operand(0); - HloComputation* body = xla_while->while_body(); - HloComputation* cond = xla_while->while_condition(); - - // Negate the root of the cond. 
- cond->set_root_instruction(cond->AddInstruction( - HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kNot, cond->root_instruction()))); - - // Replace the body with a computation which just passes through its - // parameter. - body->set_root_instruction(body->parameter_instruction(0)); - - // DCE the dead code in the body. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(body).size(), 7); - EXPECT_EQ(sequence.at(cond).size(), 4); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(body).size(), 1); - EXPECT_EQ(sequence.at(cond).size(), 5); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 3a1bc4e328b89d75efde7e7afeb0e52ceed4d8f9..b3949f3a6d7176950c61cafb0830d1175f17758d 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -50,9 +51,16 @@ StatusOr HloModule::LaunderConstInstructionFromModule( return const_cast(hlo); } +Status HloModule::set_schedule(HloSchedule schedule) { + TF_RET_CHECK(schedule.module() == this); + TF_RETURN_IF_ERROR(schedule.Verify()); + schedule_ = std::move(schedule); + return Status::OK(); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, - bool uniquify_names) { + bool uniquify_identifiers) { if (is_entry) { CHECK_EQ(nullptr, entry_computation_); entry_computation_ = computation.get(); @@ -65,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal( } } - if (uniquify_names) { + if (uniquify_identifiers) { computation->UniquifyName(&computation_name_uniquer_); for (auto* instruction : computation->instructions()) { instruction->UniquifyName(&instruction_name_uniquer_); } + + // Pick unique IDs for each instruction. + for (auto* instruction : computation->instructions()) { + instruction->SetUniqueId(NewUniqueInstructionId()); + } + // Set unique id to this computation. + CHECK_NE(computation->root_instruction()->unique_id(), -1) + << "Root has no valid id: " << computation->ToString(); + computation->SetUniqueId(computation->root_instruction()->unique_id()); } else { // Don't uniquify the names of the computation or instruction, but we must // run the names through the uniquifiers to prevent future name collisions - // for computations and instructions created later. + // for computations and instructions created later. Also, set the + // next_unique_id_ to the one greater than the max unique id of any + // instruction (or the computation) to avoid ID collisions. 
computation_name_uniquer_.GetUniqueName(computation->name()); for (auto* instruction : computation->instructions()) { instruction_name_uniquer_.GetUniqueName(instruction->name()); + next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1); + } + if (next_unique_id_ < computation->unique_id() + 1) { + next_unique_id_ = computation->unique_id() + 1; } } - // Pick unique IDs for each instruction. - for (auto* instruction : computation->instructions()) { - instruction->SetUniqueId(NewUniqueInstructionId()); - } - // Set unique id to this computation. - CHECK_NE(computation->root_instruction()->unique_id(), -1) - << "Root has no valid id: " << computation->ToString(); - computation->SetUniqueId(computation->root_instruction()->unique_id()); - computation->set_parent(this); computations_.push_back(std::move(computation)); return computations_.back().get(); @@ -97,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal( HloComputation* HloModule::AddEntryComputation( std::unique_ptr computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/true, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { @@ -114,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { HloComputation* HloModule::AddEmbeddedComputation( std::unique_ptr computation) { return AddComputationInternal(std::move(computation), /*is_entry=*/false, - /*uniquify_names=*/true); + /*uniquify_identifiers=*/true); } void HloModule::ReplaceComputations( @@ -198,12 +212,23 @@ void HloModule::ReplaceComputations( string HloModule::ToString(const HloPrintOptions& options) const { std::ostringstream s; - s << "HloModule " << name() << "\n\n"; + s << "HloModule " << name(); + if (has_schedule()) { + TF_CHECK_OK(schedule().Verify()); + s << ", is_scheduled=true"; + } + s << "\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { if 
(computation == entry_computation()) { s << "ENTRY "; } - s << computation->ToString(options) << "\n\n"; + if (has_schedule() && schedule().is_computation_scheduled(computation)) { + s << computation->ToString( + options, schedule().sequence(computation).instructions()) + << "\n\n"; + } else { + s << computation->ToString(options) << "\n\n"; + } } return s.str(); } @@ -221,12 +246,18 @@ HloModuleProto HloModule::ToProto() const { } proto.add_computations()->Swap(&computation_proto); } + if (has_schedule()) { + *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); + } return proto; } /* static */ StatusOr> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config) { + VLOG(2) << "CreateFromProto()"; + XLA_VLOG_LINES(2, proto.DebugString()); + // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. TF_RET_CHECK(proto.has_program_shape()) @@ -290,25 +321,42 @@ StatusOr> HloModule::CreateFromProto( // Don't uniquify names because we want names to be stable across // serialization and deserialization. module->AddComputationInternal(std::move(computation), is_entry, - /*uniquify_names=*/false); + /*uniquify_identifiers=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); - // Because we didn't uniquify the names, double-check that the instruction and - // computation names are unique from the proto. + // Because we didn't uniquify the names or the ids, double-check that the + // instruction and computation names and ids are unique from the proto. 
tensorflow::gtl::FlatSet computation_names; tensorflow::gtl::FlatSet instruction_names; + tensorflow::gtl::FlatSet computation_ids; + tensorflow::gtl::FlatSet instruction_ids; for (HloComputation* computation : module->computations()) { TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); computation_names.insert(computation->name()); + + TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id())) + << "Computation id is not unique: " << computation->unique_id(); + computation_ids.insert(computation->unique_id()); for (HloInstruction* instruction : computation->instructions()) { TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) << "Instruction name is not unique: " << instruction->name(); instruction_names.insert(instruction->name()); + + TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id())) + << "Instruction id is not unique: " << instruction->unique_id(); + instruction_ids.insert(instruction->unique_id()); } } + if (proto.has_schedule()) { + TF_ASSIGN_OR_RETURN( + HloSchedule schedule, + HloSchedule::CreateFromProto(module.get(), proto.schedule())); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + } + return std::move(module); } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 3c3371426b7a6a054053fe6761f87c3b5a097699..3bc2d13781aa72738d695e37a02983ee82c6037d 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -32,6 +33,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -235,10 +237,23 @@ class HloModule { StatusOr LaunderConstInstructionFromModule( const HloInstruction* hlo); + // Sets the schedule of the module to the given schedule. + Status set_schedule(HloSchedule schedule); + + // Clears the schedule of the module. + void clear_schedule() { schedule_.reset(); } + + // Returns true if the module has a schedule set. + bool has_schedule() const { return schedule_.has_value(); } + + // Returns the schedule of the module. CHECK fails if no schedule is set. + const HloSchedule& schedule() const { return *schedule_; } + HloSchedule& schedule() { return *schedule_; } + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, - bool uniquify_names); + bool uniquify_identifiers); const string name_; HloModuleConfig config_; @@ -262,6 +277,11 @@ class HloModule { static std::atomic next_unique_module_id_; // A unique id to label modules with. int unique_id_; + + // The HloSchedule of the module. The schedule if it exists contains a + // sequential order of instructions for each non-fusion computation in the + // module. 
+ absl::optional schedule_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 3f1e1cc73eeb9debe5eb6278ab192fdf9b8cc10f..68c18836eb01484b819e7b7bd26f099dcf56e7ba 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -106,9 +106,6 @@ class HloModuleConfig { absl::optional entry_computation_layout_; - // Whether this is a 'host module'. - bool is_host_module_ = false; - // Module/graph-level seed handle. uint64 seed_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc index 98d20315e399c6b1a3979b5d11a89ef93869f4d9..f7be5cae2239e81d9aa1f5fb811a37c6086b028f 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -36,23 +36,6 @@ namespace xla { namespace { -bool HasSendRecv(HloComputation* computation) { - for (auto* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kSend || - instruction->opcode() == HloOpcode::kSendDone || - instruction->opcode() == HloOpcode::kRecv || - instruction->opcode() == HloOpcode::kRecvDone) { - return true; - } - for (auto* sub_computation : instruction->called_computations()) { - if (HasSendRecv(sub_computation)) { - return true; - } - } - } - return false; -} - StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { bool changed = false; for (auto* computation : module->computations()) { @@ -68,9 +51,10 @@ StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { if (!ShapeUtil::IsTuple(xla_while->shape()) || while_body_root->opcode() != HloOpcode::kTuple || - HasSendRecv(while_body_comp)) { + while_body_comp->HasSideEffect() || + xla_while->while_condition()->HasSideEffect()) { // Only run DCE on tuple-shaped while loops where body root is Tuple, - // with no send/recv 
instructions. + // with no I/O instructions. VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); continue; } diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index 363862e4905fc13a4ef07aeaac255259fc6b86ba..bf66cc6bc37a5e11c9ecfc07a62ba0ea5ca11a03 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -367,5 +367,77 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { "while.2", 1)); } +// Tests that a while whose body has outfeed operations is not DCE-ed. +TEST_F(HloModuleDceTest, WhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + WhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + WhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); +} + +// Tests that if a loop variable is not referenced outside of a kWhile, the loop +// variable changes are not elided within the loop body, if the condition +// computation uses them. 
+TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) { + auto module = ParseHloString(R"( + HloModule InfiniteLoop + WhileBody { + body_param = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2) + } + WhileCondition { + cond_param = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + p0 = (s32[]) parameter(0) + get-tuple-element.5 = s32[] get-tuple-element(p0), index=0 + constant.3 = s32[] constant(0) + tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5) + while = (s32[], s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1 + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc new file mode 100644 index 0000000000000000000000000000000000000000..f9b56ef4643f2ca88e56456ae6c990161adb5085 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group.cc @@ -0,0 +1,91 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" + +namespace xla { + +HloModuleGroup::HloModuleGroup(absl::string_view name, + std::unique_ptr module) + : name_(name) { + push_back(std::move(module)); +} + +HloModuleGroup::HloModuleGroup(absl::string_view name, + absl::Span> modules) + : name_(name) { + for (auto& module : modules) { + push_back(std::move(module)); + } +} + +std::vector> HloModuleGroup::ConsumeModules() { + std::vector> ret_modules = std::move(modules_); + + // Clear everything so the object state is in a known (empty) state. 
+ modules_.clear(); + module_ptrs_.clear(); + return ret_modules; +} + +string HloModuleGroup::ToString() const { + std::ostringstream s; + s << "HloModuleGroup " << name() << "\n\n"; + for (const HloModule* module : modules()) { + s << module->ToString() << "\n"; + } + return s.str(); +} + +HloModuleGroupProto HloModuleGroup::ToProto() const { + HloModuleGroupProto proto; + proto.set_name(name()); + for (const HloModule* module : modules()) { + *proto.add_hlo_modules() = module->ToProto(); + } + return proto; +} + +/* static */ StatusOr HloModuleGroup::CreateFromProto( + const HloModuleGroupProto& proto, + absl::Span module_configs) { + TF_RET_CHECK(!proto.name().empty()) << "Module group name cannot be empty"; + TF_RET_CHECK(proto.hlo_modules_size() > 0) + << "Module group must have at least one HLO module"; + TF_RET_CHECK(proto.hlo_modules_size() == module_configs.size()); + + std::vector> modules; + for (int i = 0; i < proto.hlo_modules_size(); ++i) { + const HloModuleProto& module_proto = proto.hlo_modules(i); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module, + HloModule::CreateFromProto(module_proto, module_configs[i])); + modules.push_back(std::move(module)); + } + + return HloModuleGroup(proto.name(), absl::MakeSpan(modules)); +} + +void HloModuleGroup::push_back(std::unique_ptr module) { + modules_.push_back(std::move(module)); + module_ptrs_.push_back(modules_.back().get()); +} + +std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group) { + out << group.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h new file mode 100644 index 0000000000000000000000000000000000000000..7338be8b9c5ed47f0ba5829cc1d603b21f00b6e0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { + +// An abstraction representing an ordered set of HLO modules built to run +// concurrently across different devices. +class HloModuleGroup { + public: + // Construct an empty module group. + explicit HloModuleGroup(absl::string_view name) : name_(name) {} + + // Construct a module group containing a single module. + HloModuleGroup(absl::string_view name, std::unique_ptr module); + + // Construct a module group containing any number of modules. + HloModuleGroup(absl::string_view name, + absl::Span> modules); + + // Returns the modules contained in the group. + const std::vector& modules() const { return module_ptrs_; } + + // Returns a module at a particular index. + HloModule& module(int index) const { return *module_ptrs_.at(index); } + + // Add a module to the back of vector of modules in the group. + void push_back(std::unique_ptr module); + + // Moves all modules from the group into the returned vector. After this + // method runs, the module group will be empty. 
+ std::vector> ConsumeModules(); + + string name() const { return name_; } + string ToString() const; + + // Serialize the module group to/from a proto. + HloModuleGroupProto ToProto() const; + static StatusOr CreateFromProto( + const HloModuleGroupProto& proto, + absl::Span module_configs); + + private: + string name_; + + // Vector of modules as std::unique_ptrs. + std::vector> modules_; + + // Vector of modules as normal pointers. This vector is kept in sync with + // modules_ as modules are added to the group with push_back. + std::vector module_ptrs_; +}; + +std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ebf790ba6f1b5f9a7d4be8a8324420dbe11793f4 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_group.h" + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { + +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +class HloModuleGroupTest : public HloTestBase { + protected: + HloModuleGroupTest() = default; +}; + +TEST_F(HloModuleGroupTest, SingleModule) { + const string text = R"( +HloModule simple_module + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + HloModuleGroup group(TestName(), std::move(module)); + + EXPECT_EQ(group.modules().size(), 1); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy, + HloModuleGroup::CreateFromProto( + group.ToProto(), {group.module(0).config()})); + EXPECT_EQ(group_copy.modules().size(), 1); + EXPECT_THAT( + group_copy.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + + std::vector> modules = group.ConsumeModules(); + EXPECT_EQ(modules.size(), 1); + EXPECT_EQ(group.modules().size(), 0); +} + +TEST_F(HloModuleGroupTest, MultipleModules) { + const string text_0 = R"( +HloModule module0 + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + const string text_1 = R"( +HloModule module1 + +ENTRY %entry (a: f32[]) -> f32[] { + ROOT 
%a = f32[] parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0, + ParseHloString(text_0)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1, + ParseHloString(text_1)); + std::vector> modules; + modules.push_back(std::move(module_0)); + modules.push_back(std::move(module_1)); + HloModuleGroup group(TestName(), absl::MakeSpan(modules)); + EXPECT_EQ(group.modules().size(), 2); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + EXPECT_THAT(group.module(1).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter())); + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy, + HloModuleGroup::CreateFromProto( + group.ToProto(), {group.module(0).config(), + group.module(1).config()})); + EXPECT_EQ(group_copy.modules().size(), 2); +} + +TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) { + const string text_0 = R"( +HloModule module0 + +ENTRY %entry (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} +)"; + const string text_1 = R"( +HloModule module1 + +ENTRY %entry (a: f32[]) -> f32[] { + ROOT %a = f32[] parameter(0) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_0, + ParseHloString(text_0)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_1, + ParseHloString(text_1)); + HloModuleGroup group(TestName()); + group.push_back(std::move(module_0)); + group.push_back(std::move(module_1)); + + EXPECT_EQ(group.modules().size(), 2); + EXPECT_THAT( + group.module(0).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add())); + EXPECT_THAT(group.module(1).entry_computation()->instructions(), + ::testing::ElementsAre(op::Parameter())); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 
4bc1bacd7ddd6573e75eb5e2b38b24ff5899d330..39f38b417ab0e8b54864176d8d1e0ad1a422eca6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -19,10 +19,13 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" - +#include "tensorflow/core/lib/core/status_test_util.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" @@ -30,6 +33,8 @@ namespace xla { namespace { +namespace op = ::xla::testing::opcode_matchers; + class HloModuleTest : public HloTestBase { protected: HloModuleTest() {} @@ -194,6 +199,153 @@ TEST_F(HloModuleTest, UniqueModuleId) { EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } +TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_FALSE(module_copy->has_schedule()); +} + 
+TEST_F(HloModuleTest, ProtoSerializationWithSchedule) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_TRUE(module_copy->has_schedule()); + TF_ASSERT_OK(module_copy->schedule().Verify()); + EXPECT_EQ(module_copy->schedule().sequences().size(), 1); + ASSERT_TRUE(module_copy->schedule().is_computation_scheduled( + module_copy->entry_computation())); + EXPECT_THAT( + module_copy->schedule() + .sequence(module_copy->entry_computation()) + .instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + +TEST_F(HloModuleTest, ProtoSerializationPreservesIds) { + // Verify that serializing then deserializing an HLO proto preserves the + // unique IDs of the instruction and module. 
+ const string text = + R"(HloModule ReduceR3ToR2_module + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY ReduceR3ToR2.v3 { + input = f32[8,16,256]{2,1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + + // Perform various transformations on the graph: + // + // * clone the reduction function + // * replace use of reduction function with the clone. + // * add a random instruction to the entry computation. + // + // This will create instruction and computation IDs which are interesting: + // not consecutive and not densely packed. + HloComputation* entry = module->entry_computation(); + HloInstruction* root = entry->root_instruction(); + HloComputation* reduction = root->to_apply(); + HloComputation* reduction_clone = + module->AddEmbeddedComputation(reduction->Clone()); + root->set_to_apply(reduction_clone); + TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction)); + HloInstruction* negate = entry->AddInstruction( + HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root)); + entry->set_root_instruction(negate); + + // Schedule the transformed module, this verifies that the serialized schedule + // is robust against non-consecutive IDs as well (b/114712358). + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + HloMemoryScheduler scheduler(size_fn); + TF_ASSERT_OK(scheduler.Run(module.get()).status()); + ASSERT_TRUE(module->has_schedule()); + + // Serialize and deserialize and verify that the instruction and computations + // unique ids are the same. 
+ TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + + // The module IDs should *not* be the same because module ids must be globally + // unique. + EXPECT_NE(module->unique_id(), module_copy->unique_id()); + + // Verify that the computations and instructions all have the same unique id. + auto computation_copy_it = module_copy->computations().begin(); + for (const HloComputation* computation_orig : module->computations()) { + const HloComputation* computation_copy = *computation_copy_it++; + EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id()) + << absl::StrFormat( + "ID of original computation %s != ID of deserialized " + "computation %s: %d != %d", + computation_orig->name(), computation_copy->name(), + computation_orig->unique_id(), computation_copy->unique_id()); + + auto instruction_copy_it = computation_copy->instructions().begin(); + for (const HloInstruction* instruction_orig : + computation_orig->instructions()) { + const HloInstruction* instruction_copy = *instruction_copy_it++; + EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id()) + << absl::StrFormat( + "ID of original instruction %s != ID of deserialized " + "instruction %s: %d != %d", + instruction_orig->name(), instruction_copy->name(), + instruction_orig->unique_id(), instruction_copy->unique_id()); + } + } + + // Verify that the next unique ID which the module would have handed out is + // greater than the unique id of any instruction. 
+ int next_id = module_copy->NewUniqueInstructionId(); + for (const HloComputation* computation : module_copy->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + EXPECT_GT(next_id, instruction->unique_id()); + } + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 0581d5c40425d332d89cc92ca6c6b0b10dd8fcf1..f1dc08bafa17a2dd68a7e922d4b84658bbf2589c 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -252,6 +253,12 @@ bool HloOrdering::LiveRangeStrictlyBefore( VLOG(4) << a << " not defined before " << b; return false; } + + if (a.live_out_of_module()) { + VLOG(4) << a << " is live out of module and defined before " << b; + return false; + } + // All uses of 'a' must be before 'b' is defined. 
for (const HloUse& use : a.uses()) { if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), @@ -264,6 +271,18 @@ bool HloOrdering::LiveRangeStrictlyBefore( return false; } } + + if (a.instruction()->parent() == b.instruction()->parent()) { + for (const HloPosition& position : a.positions()) { + if (position.instruction == + a.instruction()->parent()->root_instruction()) { + VLOG(4) << a << " is live out of computation and defined before " << b + << " which is in same computation"; + return false; + } + } + } + return true; } @@ -274,23 +293,6 @@ bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b, !LiveRangeStrictlyBefore(b, a, dataflow); } -HloOrderingProto HloOrdering::ToProto() const { - HloOrderingProto proto; - for (const auto& computation : module_->computations()) { - const std::vector* sequence = - SequentialOrder(*computation); - if (sequence != nullptr) { - HloOrderingProto::SequentialComputation* proto_computation = - proto.add_sequential_computations(); - proto_computation->set_computation_name(computation->name()); - for (const HloInstruction* instruction : *sequence) { - *proto_computation->add_instruction_names() = instruction->name(); - } - } - } - return proto; -} - PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) : HloOrdering(module) {} @@ -336,15 +338,24 @@ string DependencyHloOrdering::ToString() const { return ToStringHelper("DependencyHloOrdering"); } -SequentialHloOrdering::SequentialHloOrdering( - const HloModule* module, const HloModuleSequence& module_sequence) - : HloOrdering(module), module_sequence_(module_sequence) { +SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule) + : HloOrdering(schedule.module()), schedule_(schedule) { + Initialize(); +} + +SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule) + : HloOrdering(schedule.module()), schedule_(std::move(schedule)) { + Initialize(); +} + +void SequentialHloOrdering::Initialize() { // Create a map 
from instruction to its order position. - for (auto computation_order : module_sequence_) { - const std::vector& order = computation_order.second; + TF_DCHECK_OK(schedule_.Verify()); + for (const auto& computation_sequence : schedule_.sequences()) { + const std::vector& order = + computation_sequence.second.instructions(); for (int i = 0; i < order.size(); ++i) { - DCHECK_EQ(0, order_position_.count(order[i])); - order_position_.emplace(order[i], i); + InsertOrDie(&order_position_, order[i], i); } } } @@ -362,49 +373,13 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( const std::vector* SequentialHloOrdering::SequentialOrder( const HloComputation& computation) const { - auto find_it = module_sequence_.find(&computation); - return find_it == module_sequence_.end() ? nullptr : &find_it->second; + return schedule_.is_computation_scheduled(&computation) + ? &schedule_.sequence(&computation).instructions() + : nullptr; } string SequentialHloOrdering::ToString() const { - std::vector pieces; - pieces.push_back("SequentialHloOrdering"); - for (auto* computation : module_->computations()) { - pieces.push_back( - absl::StrFormat("computation %s order:", computation->name())); - // Gather all instructions in the module sequence for this computation and - // sort them by their position. 
- std::vector instructions; - for (auto& instruction_position : order_position_) { - const HloInstruction* instruction = instruction_position.first; - if (instruction->parent() == computation) { - instructions.push_back(instruction); - } - } - std::sort(instructions.begin(), instructions.end(), - [this](const HloInstruction* a, const HloInstruction* b) { - return order_position_.at(a) < order_position_.at(b); - }); - for (auto instruction : instructions) { - pieces.push_back(absl::StrFormat(" %s", instruction->name())); - } - } - return absl::StrJoin(pieces, "\n"); -} - -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence) { - for (auto computation_pair : module_sequence) { - const HloComputation* computation = computation_pair.first; - const std::vector& computation_sequence = - computation_pair.second; - out << "Computation " << computation->name() << ":\n"; - for (auto* instruction : computation_sequence) { - out << " " << instruction->name() << "\n"; - } - } - return out; + return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 985f3fa64d8767b0c0063ee900f7d11c3b7f6d4a..b0361c3f02922bcaa14d52ad3b240701080f9b58 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -71,10 +72,6 @@ class HloOrdering { virtual string ToString() const = 0; - // Returns the serialized representation of this ordering. - // Only sequential computation orders are represented. - HloOrderingProto ToProto() const; - protected: // Returns true if instruction 'a' executes before instruction 'b'. // Precondition: 'a' and 'b' are in the same computation. @@ -183,17 +180,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: - // TODO(dimvar): HloModuleSequence is not a good name because it sounds like - // a sequence of modules, instead of a map of schedules for all computations - // in a module. We should change it at some point. - // - // A sequence of instructions for each computation in the module. - using HloModuleSequence = - tensorflow::gtl::FlatMap>; - - SequentialHloOrdering(const HloModule* module, - const HloModuleSequence& module_sequence); + SequentialHloOrdering(const HloSchedule& schedule); + SequentialHloOrdering(HloSchedule&& schedule); ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation. 
@@ -203,10 +191,12 @@ class SequentialHloOrdering : public HloOrdering { string ToString() const override; protected: + void Initialize(); + bool ExecutesBeforeInSameComputation(const HloInstruction* a, const HloInstruction* b) const override; - const HloModuleSequence module_sequence_; + const HloSchedule schedule_; // The position of every instruction in the HLO module in its respective // computation sequence (a value of zero indicates the instruction is first in @@ -217,10 +207,6 @@ class SequentialHloOrdering : public HloOrdering { tensorflow::gtl::FlatMap order_position_; }; -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 126d3a2d9c70bff1d2a022e395652049768d6d21..00970bcda34209d33867099d0bcf3b2902d52ae8 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -23,11 +23,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -376,5 +377,104 @@ ENTRY root { dataflow->GetValueDefinedAt(add_3))); } +TEST_F(HloOrderingTest, + ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) { + // Tests that values live out of the module should interfere with values + // defined after the root instruction. That is: + // + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloComputation* entry = + module->AddEntryComputation(builder.Build(/*root_instruction=*/root)); + + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param, root, dead}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + 
EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + +TEST_F(HloOrderingTest, + ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) { + // Tests that values live out of a computation should interfere with values + // defined after the root instruction of the computation. That is: + // + // subcomputation: + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // entry computation: + // %c = constant(42.0) + // ROOT %call = call({%c}), subcomputation + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto subbuilder = HloComputation::Builder(TestName() + ".sub"); + HloInstruction* param = subbuilder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = subbuilder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = subbuilder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloComputation* subcomputation = module->AddEmbeddedComputation( + subbuilder.Build(/*root_instruction=*/root)); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction* call = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {c}, subcomputation)); + HloComputation* entry = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(subcomputation, {param, root, dead}); + schedule.set_sequence(entry, {c, call}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + 
TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index ea8e6a239a22335b644369a78791029c36315560..11caa89c545e8fbfad96a9ab8e448a68a565e423 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -44,6 +45,20 @@ using absl::StrJoin; const double kF16max = 65504; +// Creates and returns a schedule created using the order of the instructions in +// the HloComputation::instructions() vectors in the module. +HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { + HloSchedule schedule(module); + for (const HloComputation* computation : module->computations()) { + if (!computation->IsFusionComputation()) { + for (const HloInstruction* instruction : computation->instructions()) { + schedule.GetOrCreateSequence(computation).push_back(instruction); + } + } + } + return schedule; +} + // Parser for the HloModule::ToString() format text. 
class HloParser { public: @@ -90,16 +105,13 @@ class HloParser { string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); bool ParseControlPredecessors(HloInstruction* instruction); - bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseNonTupleLiteral(std::unique_ptr* literal, - const Shape& shape); - bool ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape); - bool ParseSparseLiteral(std::unique_ptr* literal, - const Shape& shape); + bool ParseLiteral(Literal* literal, const Shape& shape); + bool ParseTupleLiteral(Literal* literal, const Shape& shape); + bool ParseNonTupleLiteral(Literal* literal, const Shape& shape); + bool ParseDenseLiteral(Literal* literal, const Shape& shape); + bool ParseSparseLiteral(Literal* literal, const Shape& shape); template - bool ParseSparseLiteralHelper(std::unique_ptr* literal, - const Shape& shape); + bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape); // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. 
@@ -221,7 +233,7 @@ class HloParser { bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); - bool ParsePrecisionList(std::vector* result); + bool ParsePrecisionList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); @@ -240,7 +252,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); - bool ParsePrecision(PrecisionConfigProto::Precision* result); + bool ParsePrecision(PrecisionConfig::Precision* result); bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -366,9 +378,25 @@ bool HloParser::ParseHloModule() { return false; } + absl::optional is_scheduled; + std::unordered_map attrs; + attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled}; + if (!ParseAttributes(attrs)) { + return false; + } + module_ = absl::make_unique(name, config_); - return ParseComputations(); + if (!ParseComputations()) { + return false; + } + + if (is_scheduled.has_value() && *is_scheduled) { + TF_CHECK_OK( + module_->set_schedule(ScheduleFromInstructionOrder(module_.get()))); + } + + return true; } // computations ::= (computation)+ @@ -530,10 +558,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, attrs["backend_config"] = {/*required=*/false, AttrTy::kString, &backend_config}; - optional> operand_precision; - attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, - &operand_precision}; - HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -550,7 +574,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConstant: { - std::unique_ptr literal; + Literal literal; if (!ParseToken(TokKind::kLparen, "expects '(' before constant literal") || !ParseLiteral(&literal, shape) || @@ -913,6 
+937,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + optional> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; @@ -923,9 +950,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!feature_group_count) { feature_group_count = 1; } + PrecisionConfig precision_config; + if (operand_precision) { + *precision_config.mutable_operand_precision() = { + operand_precision->begin(), operand_precision->end()}; + } else { + precision_config.mutable_operand_precision()->Resize( + operands.size(), PrecisionConfig::DEFAULT); + } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( - shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums, - feature_group_count.value())); + shape, /*lhs=*/operands[0], /*rhs=*/operands[1], + feature_group_count.value(), *window, *dnums, precision_config)); break; } case HloOpcode::kFft: { @@ -1241,11 +1276,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional custom_call_target; optional window; optional dnums; + optional feature_group_count; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/false, AttrTy::kConvolutionDimensionNumbers, &dnums}; + attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, + &feature_group_count}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -1257,6 +1295,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (dnums.has_value()) { instruction->set_convolution_dimension_numbers(*dnums); } + if 
(feature_group_count.has_value()) { + instruction->set_feature_group_count(*feature_group_count); + } break; } case HloOpcode::kDot: { @@ -1272,6 +1313,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; + optional> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { @@ -1296,8 +1340,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, rhs_batch_dims->end()}; } - instruction = builder->AddInstruction( - HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); + PrecisionConfig precision_config; + if (operand_precision) { + *precision_config.mutable_operand_precision() = { + operand_precision->begin(), operand_precision->end()}; + } else { + precision_config.mutable_operand_precision()->Resize( + operands.size(), PrecisionConfig::DEFAULT); + } + + instruction = builder->AddInstruction(HloInstruction::CreateDot( + shape, operands[0], operands[1], dnum, precision_config)); break; } case HloOpcode::kGather: { @@ -1414,12 +1467,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (backend_config) { instruction->set_raw_backend_config_string(std::move(*backend_config)); } - if (operand_precision) { - PrecisionConfigProto precision_config; - *precision_config.mutable_operand_precision() = {operand_precision->begin(), - operand_precision->end()}; - instruction->set_precision_config(precision_config); - } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1760,8 +1807,7 @@ bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { // literal // ::= tuple // ::= non_tuple -bool HloParser::ParseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseLiteral(Literal* literal, 
const Shape& shape) { return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) : ParseNonTupleLiteral(literal, shape); } @@ -1771,8 +1817,7 @@ bool HloParser::ParseLiteral(std::unique_ptr* literal, // literal_list // ::= /*empty*/ // ::= literal (',' literal)* -bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return TokenError(StrCat("expects tuple constant in shape ", ShapeUtil::HumanString(shape))); @@ -1780,8 +1825,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } - std::vector> elements( - ShapeUtil::TupleElementCount(shape)); + std::vector elements(ShapeUtil::TupleElementCount(shape)); if (lexer_.GetKind() == TokKind::kRparen) { // empty @@ -1807,8 +1851,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, // ::= rank01 // ::= rank2345 // rank2345 ::= shape sparse_or_nested_array -bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { if (LayoutUtil::IsSparseArray(shape)) { return ParseSparseLiteral(literal, shape); } @@ -1817,8 +1860,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, return ParseDenseLiteral(literal, shape); } -bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; @@ -1912,7 +1954,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // TODO(congliu): bool type literals with rank >= 1 are actually // printed in a compact form instead of "true" or "false". Fix that. 
if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, - linear_index++, literal->get())) { + linear_index++, literal)) { return false; } lexer_.Lex(); @@ -1923,7 +1965,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { @@ -1934,7 +1976,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, loc, StrCat("expect floating point value for primitive type: ", PrimitiveType_Name(shape.element_type()))); } - if (!SetValueInLiteral(value, linear_index++, literal->get())) { + if (!SetValueInLiteral(value, linear_index++, literal)) { return false; } } else { @@ -1946,12 +1988,11 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, } // end of switch } while (nest_level > 0); - *literal = (*literal)->Relayout(shape.layout()); + *literal = literal->Relayout(shape.layout()); return true; } -bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { if (!EatShapeAndCheckCompatible(shape)) { return false; } @@ -1991,13 +2032,12 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, } template -bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, - const Shape& shape) { +bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) { std::vector index; tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = absl::make_unique(shape); + *literal = Literal(shape); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { @@ -2071,7 +2111,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return false; } - if 
((*literal)->sparse_element_count() + 1 == + if (literal->sparse_element_count() + 1 == LayoutUtil::MaxSparseElements(shape.layout())) { return Error( lexer_.GetLoc(), @@ -2079,10 +2119,10 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, ShapeUtil::HumanStringWithLayout(shape))); } - (*literal)->AppendSparseElement(index, value); + literal->AppendSparseElement(index, value); } - (*literal)->SortSparseElements(); + literal->SortSparseElements(); return true; } @@ -2397,11 +2437,11 @@ bool HloParser::ParseAttributeHelper( return ParseDomain(static_cast(attr_out_ptr)); } case AttrTy::kPrecisionList: { - std::vector result; + std::vector result; if (!ParsePrecisionList(&result)) { return false; } - static_cast>*>( + static_cast>*>( attr_out_ptr) ->emplace(result); return true; @@ -2685,9 +2725,9 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { // ::= /*empty*/ // ::= precision_val (delim precision_val)* bool HloParser::ParsePrecisionList( - std::vector* result) { + std::vector* result) { auto parse_and_add_item = [&]() { - PrecisionConfigProto::Precision item; + PrecisionConfig::Precision item; if (!ParsePrecision(&item)) { return false; } @@ -3019,7 +3059,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { return true; } -bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) { +bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { VLOG(1) << "ParsePrecision"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects random distribution"); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 759789437c12d489ee607638e736dfd6a6e1dda1..cca50fab5444d5e23c02952d56566b643a2192a4 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -19,6 +19,8 @@ limitations under the License. 
#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -382,7 +384,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1, operand_precision={high,default} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default} } )" @@ -395,7 +397,7 @@ R"(HloModule ConvolveR2_module ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { %input = f32[1,2]{1,0} parameter(0) %filter = f32[1,1]{1,0} parameter(1) - ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1 + ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf } )" @@ -408,7 +410,7 @@ R"(HloModule ConvolveBackward_module ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { %input = f32[128,7,7,512]{0,3,2,1} parameter(0) %filter = f32[3,3,512,512]{3,2,1,0} parameter(1) - ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1 + ROOT 
%convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f } )" @@ -1121,18 +1123,31 @@ ENTRY Iota { )" }, -// custom-call with window and dim_labels +// custom-call with window, dim_labels and feature_group_count { -"CustomCallWithWindowAndDimLabels", -R"(HloModule CustomCallWithWindowAndDimLabels +"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount", +R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount ENTRY Computation { - ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target" + ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target" } )" + }, +// is_scheduled=true attribute +{ +"ScheduledModule", +R"(HloModule scheduled_module, is_scheduled=true + +ENTRY Sort { + keys = f32[1024]{0} parameter(0) + values = s32[1024]{0} parameter(1) + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} } - }); + +)" +} +}); // clang-format on } @@ -1775,5 +1790,107 @@ TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) { ::testing::HasSubstr("Operand broadcast had no shape in HLO text")); } +TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { + const string text = + R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Convolution(op::Parameter(0), op::Parameter(1))); + auto* convolution = + Cast(computation->root_instruction()); + EXPECT_EQ(convolution->feature_group_count(), 1); +} + +TEST_F(HloParserTest, 
IsScheduledIsFalse) { + const string text = R"( +HloModule axpy_module, is_scheduled=false + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledNotPresent) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledIsTrue) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + 
module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(), + op::Multiply(), op::Parameter(), op::Add())); +} + +TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) { + // As above but with a different schedule order. + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 3460679558d185d1e022660d9a1d23176d0d96bf..b9c0b0c4ee1957fce48641230cef6391bcc9180e 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -23,11 +23,8 @@ namespace xla { HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment) { - HloOrderingProto proto_ordering = - assignment.liveness().hlo_ordering().ToProto(); BufferAssignmentProto
proto_assignment = assignment.ToProto(); HloProto proto = MakeHloProto(module); - proto.mutable_hlo_ordering()->Swap(&proto_ordering); proto.mutable_buffer_assignment()->Swap(&proto_assignment); return proto; } diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc index 585c95972b0e01abc14543205af71b4b0c0bdf3c..d9848cee0bfa904a90aea4626c3ee62c2cbb45b6 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc @@ -20,13 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" namespace xla { namespace { -class HloReachabilityTest : public HloTestBase {}; +class HloReachabilityTest : public HloVerifiedTestBase {}; TEST_F(HloReachabilityTest, Reachability) { // Construct and test a reachability graph of the following form: diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index c9629926eae5132f683a353a430a724a66ef3d60..bd6dd79b679729adb6691ef809b19f06c6d5dd05 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -27,15 +27,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -962,8 +961,7 @@ StatusOr HloRematerialization::CalledComputationsMemoryUsage( } StatusOr HloRematerialization::RematerializeComputation( - HloComputation* computation, - SequentialHloOrdering::HloModuleSequence* sequence, + HloComputation* computation, HloSchedule* schedule, int64 memory_limit_bytes) { VLOG(1) << "Rematerializing computation " << computation->name() << " with limit " << HumanReadableNumBytes(memory_limit_bytes); @@ -971,7 +969,8 @@ StatusOr HloRematerialization::RematerializeComputation( << HumanReadableNumBytes(computation_peak_memory_.at(computation)); CHECK(!ContainsKey(rematerialized_computations_, computation)); - InstructionList instruction_list(sequence->at(computation)); + InstructionList instruction_list( + schedule->sequence(computation).instructions()); MemoryUsageTracker memory_tracker(computation, size_function_, *points_to_analysis_, instruction_list); bool changed = false; @@ -1145,7 +1144,7 @@ StatusOr HloRematerialization::RematerializeComputation( 0, memory_limit_bytes - memory_tracker.memory_usage()); TF_ASSIGN_OR_RETURN( bool 
subcomputation_changed, - RematerializeComputation(called_computation, sequence, + RematerializeComputation(called_computation, schedule, subcomputation_memory_limit_bytes)); changed |= subcomputation_changed; } @@ -1179,12 +1178,12 @@ StatusOr HloRematerialization::RematerializeComputation( computation_peak_memory_.at(computation) = peak_memory; // Update order to include rematerialized instructions. - auto& dst = sequence->at(computation); - dst.clear(); + HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation); + sequence.clear(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { const HloInstruction* instruction = item->instruction; - dst.push_back(instruction); + sequence.push_back(instruction); } rematerialized_computations_.insert(computation); @@ -1194,59 +1193,12 @@ StatusOr HloRematerialization::RematerializeComputation( return changed; } -StatusOr HloRematerialization::Run( - HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit_bytes, RematerializationSizes* sizes, - CopyInsertion* copy_insertion) { - // The sequence is constructed entirely by this method. - TF_RET_CHECK(sequence->empty()); - +StatusOr HloRematerialization::Run(HloModule* module) { VLOG(1) << "HloRematerialization() with memory limit of " - << HumanReadableNumBytes(memory_limit_bytes); + << HumanReadableNumBytes(memory_limit_bytes_); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( - *module, - [this](const BufferValue& buffer) { - return size_function_(buffer.shape()); - }, - scheduler_algorithm_)); - if (copy_insertion) { - // We run a separate pass of copy elision here because the sequential - // ordering from the HLO schedule allows for more copies to be eliminated. 
- // TODO(b/80249101): Instead of a separate copy elision pass, use the - // ordering from the HLO schedule directly for copy insertion. - - // First create a copy of the schedule which contains HloInstruction unique - // ids instead of HloInstruction*. This is necessary for updating the - // schedule below. - // TODO(b/113175018): Remove this when the HLO schedule is self-contained - // and can update itself. - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(*sequence); - - SequentialHloOrdering ordering(module, *sequence); - TF_RETURN_IF_ERROR( - copy_insertion->RemoveUnnecessaryCopies(ordering, module)); - - // RemoveUnnecessaryCopies only considers interference when determining - // whether it is legal to remove a copy. However, copies in the graph may be - // necessary for other reason such as preventing a constant from being live - // out of the graph. So run AddSpecialCaseCopies to re-insert these copies. - // TODO(b/80249101): Break copy insertion into several passes and run each - // one once in the regular HLO pipeline. - TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module)); - - // The passes above can add and remove copies, update the schedule to - // account for these transformations. Newly added instructions will be - // placed ASAP in the schedule. 
- TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence)); - - TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference( - SequentialHloOrdering(module, *sequence), module)); - } - + TF_RET_CHECK(module->has_schedule()); TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); // Adjust memory limit to account for the output of the entry @@ -1262,7 +1214,7 @@ StatusOr HloRematerialization::Run( }); const int64 adjusted_memory_limit_bytes = - memory_limit_bytes - module_output_size; + memory_limit_bytes_ - module_output_size; VLOG(1) << "Adjusted memory limit accounting for output (" << HumanReadableNumBytes(module_output_size) << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); @@ -1271,12 +1223,14 @@ StatusOr HloRematerialization::Run( // sequential context. call_graph_ = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph_->VisitNodes( - [this, sequence](const CallGraphNode& node) -> Status { + [this, module](const CallGraphNode& node) -> Status { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], ComputePeakMemory(node.computation(), - sequence->at(node.computation()))); + module->schedule() + .sequence(node.computation()) + .instructions())); } return Status::OK(); }, @@ -1294,9 +1248,10 @@ StatusOr HloRematerialization::Run( // Subcomputations called by the entry computation will also be // rematerialized. - TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( - module->entry_computation(), sequence, - adjusted_memory_limit_bytes)); + TF_ASSIGN_OR_RETURN( + bool changed, + RematerializeComputation(module->entry_computation(), &module->schedule(), + adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an // instruction are replaced with rematerializations of the instruction. 
@@ -1305,30 +1260,7 @@ StatusOr HloRematerialization::Run( // After DCE, the module sequence may include instructions which no longer // exist. - for (const auto* computation : module->MakeNonfusionComputations()) { - if (sequence->at(computation).size() != computation->instruction_count()) { - // A size mismatch between the computation instruction count and the size - // of the ordering of instructions can only be caused by DCE. Rebuild the - // order by removing the deleted instructions from the order. - tensorflow::gtl::FlatSet instruction_set; - for (const auto& instruction : computation->instructions()) { - instruction_set.insert(instruction); - } - // Move the old order into a temporary vector, then build new order - // inplace. - std::vector& order = sequence->at(computation); - std::vector old_order; - using std::swap; - swap(order, old_order); - std::copy_if(old_order.begin(), old_order.end(), - std::back_inserter(order), - [&instruction_set](const HloInstruction* instruction) { - return ContainsKey(instruction_set, instruction); - }); - TF_RET_CHECK(sequence->at(computation).size() == - computation->instruction_count()); - } - } + TF_RETURN_IF_ERROR(module->schedule().Update()); VLOG(1) << "Rematerialized " << instructions_rematerialized_ << " instructions in module " << module->name() << "; " << net_instructions_added_ << " net instructions added"; @@ -1345,33 +1277,22 @@ StatusOr HloRematerialization::Run( << HumanReadableNumBytes(reduced_peak_memory) << " (" << reduced_peak_memory << " bytes)"; - if (sizes != nullptr) { - sizes->before_bytes = before_peak_memory; - sizes->after_bytes = current_peak_memory; + if (sizes_ != nullptr) { + sizes_->before_bytes = before_peak_memory; + sizes_->after_bytes = current_peak_memory; } XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); - if (current_peak_memory > memory_limit_bytes) { + if (current_peak_memory > memory_limit_bytes_) { LOG(WARNING) << absl::StrFormat( "Can't reduce memory use 
below %s (%d bytes) by rematerialization; " "only reduced to %s (%d bytes)", - HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes, + HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_, HumanReadableNumBytes(current_peak_memory), current_peak_memory); } return changed; } -/* static */ StatusOr HloRematerialization::RematerializeAndSchedule( - const HloRematerialization::ShapeSizeFunction& size_function, - int64 memory_limit_bytes, HloModule* hlo_module, - MemorySchedulerAlgorithm scheduler_algorithm, - SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes, CopyInsertion* copy_insertion) { - HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes, - copy_insertion); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 2ec004350ad88ff31ece90ec419d90a55b965166..e2aaf18b3e482bbf777c594c7f5a22832be2ac17 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -17,16 +17,23 @@ #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" namespace xla { -class HloRematerialization { +// HLO pass which rematerializes instructions to reduce peak memory use, where +// memory use is defined as the total size of all live HLO 
instruction +// values. Parameters and constants are included in memory use estimates. +// +// CSE will undo the effects of this optimization and should not be run after +// this pass. In general, this pass should be run very late, immediately before +// code generation. +class HloRematerialization : public HloPassInterface { public: using ShapeSizeFunction = std::function; @@ -37,10 +44,7 @@ class HloRematerialization { int64 after_bytes; }; - // Rematerialize HLO instructions in the given module to reduce peak memory - // use below memory_limit_bytes where memory use is defined as the total size - // of all live HLO instruction values. Parameters and constants are included - // in memory use estimates. Method parameters: + // Constructor parameters: // // size_function: Function which returns the size in bytes of the top-level // buffer of the given shape. @@ -48,60 +52,34 @@ class HloRematerialization { // memory_limit_bytes: The threshold number of bytes to reduce memory use to // via rematerialization. // - // hlo_module: HLO module to rematerialize instructions in. - // - // sequence: Should point to an empty HloModuleSequence. Upon return - // contains the HLO instruction order which was used for - // rematerialization. This is the order in which HLO instructions should - // be emitted to minimize memory use. - // - // sizes: Optional outparam that indicates the peak memory usage of the HLO - // module before/after rematerialization. - // - // copy_insertion: If non-null, run copy elision after scheduling. This - // pass is used to eliminate copies that were inserted by copy insertion - // before HLO scheduling. - // - // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy - // insertion is integrated with HLO scheduling. - // - // Returns whether any instructions were rematerialized. If memory use is - // already below the given limit then no instructions are rematerialized and - // false is returned. 
- // - // CSE will undo the effects of this optimization and should not be run after - // this pass. In general, this pass should be run very late immediately before - // code generation. - static StatusOr RematerializeAndSchedule( - const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, - SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr); - - protected: - HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, - const ShapeSizeFunction& size_function) - : scheduler_algorithm_(scheduler_algorithm), - size_function_(size_function) {} + // sizes: Pointer to data structure which records the peak memory usage of + // the HLO module before/after rematerialization. Values are set during + // Run(). Can be nullptr. + HloRematerialization(const ShapeSizeFunction& size_function, + int64 memory_limit_bytes, RematerializationSizes* sizes) + : size_function_(size_function), + memory_limit_bytes_(memory_limit_bytes), + sizes_(sizes) {} ~HloRematerialization() {} + absl::string_view name() const override { return "rematerialization"; } + // Runs rematerialization on the given module. Returns whether the module was - // changed. memory_limit is the target maximum peak memory usage by the - // module. sequence should be an empty HloModuleSequence. Upon return sequence - // contains the memory-minimizing order in which to emit the HLO instructions. - StatusOr Run(HloModule* module, - SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit, RematerializationSizes* sizes, - CopyInsertion* copy_insertion); + // changed. Requires that the module has a schedule set + // (HloModule::has_schedule() is true) before running. Returns whether any + // instructions were rematerialized.
If memory use is already below the limit + // specified in the constructor then no instructions are rematerialized and + // false is returned. + StatusOr Run(HloModule* module) override; + protected: // Rematerializes instructions within the given computation. 'order' is the // order in which the computation's instructions will be emitted in the // backend. Rematerialized instructions will be added to the HLO computation // and inserted into 'order'. - StatusOr RematerializeComputation( - HloComputation* computation, - SequentialHloOrdering::HloModuleSequence* sequence, - int64 computation_memory_limit); + StatusOr RematerializeComputation(HloComputation* computation, + HloSchedule* schedule, + int64 memory_limit_bytes); // Computes and returns the peak memory used by the given computation. The // peak memory is the maximum total size of all live HLO instruction values at @@ -122,6 +100,14 @@ class HloRematerialization { // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; + // The threshold number of bytes to reduce memory use to via + // rematerialization. + const int64 memory_limit_bytes_; + + // Pointer to data structure which records the peak memory usage of the HLO + // module before/after rematerialization + RematerializationSizes* sizes_; + // Call graph of the hlo_module. std::unique_ptr call_graph_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index ac8c97d380953764b66135ad1c5fcee0d481c004..f7e82fb1f88e856305f6f481a451d4cd64ba4acf 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -36,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::_; -class HloRematerializationTest : public HloTestBase { +class HloRematerializationTest : public HloVerifiedTestBase { protected: // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: @@ -141,13 +141,16 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } - StatusOr RunHloRematerialization( - int64 memory_limit_bytes, HloModule* module, - SequentialHloOrdering::HloModuleSequence* sequence) { + StatusOr RunHloRematerialization(int64 memory_limit_bytes, + HloModule* module) { TF_EXPECT_OK(verifier().Run(module).status()); - return HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - sequence, /*sizes=*/nullptr); + HloMemoryScheduler scheduler( + [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); }, + DefaultMemoryScheduler); + TF_EXPECT_OK(scheduler.Run(module).status()); + HloRematerialization remat(ByteSizeOf, memory_limit_bytes, + /*sizes=*/nullptr); + return remat.Run(module); } // Various shapes used in the canned computations. 
@@ -170,12 +173,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { const HloInstruction* concat = slice->operand(0); const HloInstruction* bcast = concat->operand(0); - SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/14 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, module)); EXPECT_TRUE(changed); // Root should not have changed. @@ -187,9 +189,13 @@ TEST_F(HloRematerializationTest, SingleComputation) { // The rematerialized broadcast should be immediate before the concat in the // sequence. - EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2], + EXPECT_EQ(module->schedule() + .sequence(computation) + .instructions()[computation->instruction_count() - 2], concat); - EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3], + EXPECT_EQ(module->schedule() + .sequence(computation) + .instructions()[computation->instruction_count() - 3], remat_bcast); } @@ -203,10 +209,9 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/20 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, module)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -242,10 +247,9 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. 
Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/17 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, module)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -276,10 +280,9 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(entry_computation->instruction_count(), 7); EXPECT_EQ(body_computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/15 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, module)); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -316,10 +319,9 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/13 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, module)); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. 
@@ -382,14 +384,13 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { ASSERT_EQ(count_rngs(entry_computation), 1); const int64 original_instruction_count = entry_computation->instruction_count(); - SequentialHloOrdering::HloModuleSequence sequence; // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( - bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), &sequence)); + bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -476,13 +477,12 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_EQ(add_3->operand(0), bcast); EXPECT_EQ(add_4->operand(0), bcast); - SequentialHloOrdering::HloModuleSequence sequence; // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -571,13 +571,12 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { EXPECT_EQ(entry_computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). 
- TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( - /*memory_limit_bytes=*/22 * 1024, - module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, module)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 66ac1f66fd035074c69d070821a951fd0e357289..fa7f216321988137dcf9104a324f5f7789869aa5 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -118,16 +118,16 @@ StatusOr> HloRunner::TransferLiteralsToDevice( } StatusOr> HloRunner::TransferLiteralsToDevice( - const absl::Span> literals) { + const absl::Span literals) { std::vector literal_pointers; literal_pointers.reserve(literals.size()); for (const auto& literal : literals) { - literal_pointers.push_back(literal.get()); + literal_pointers.push_back(&literal); } return TransferLiteralsToDevice(literal_pointers); } -StatusOr> HloRunner::TransferLiteralFromDevice( +StatusOr HloRunner::TransferLiteralFromDevice( const ShapedBuffer& buffer) { TF_ASSIGN_OR_RETURN( auto stream, backend().BorrowStream(backend().default_stream_executor())); @@ -135,7 +135,7 @@ StatusOr> HloRunner::TransferLiteralFromDevice( buffer); } -StatusOr> HloRunner::Execute( +StatusOr HloRunner::Execute( std::unique_ptr module, const absl::Span arguments, bool run_hlo_passes, ExecutionProfile* profile) { @@ -150,15 +150,15 @@ StatusOr> HloRunner::Execute( return TransferLiteralFromDevice(result); } -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const absl::Span> arguments, - bool run_hlo_passes, ExecutionProfile* profile) { +StatusOr HloRunner::Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes, + ExecutionProfile* profile) { // Construct a vector of plain pointers for the 
arguments. std::vector argument_pointers; argument_pointers.reserve(arguments.size()); for (const auto& argument : arguments) { - argument_pointers.push_back(argument.get()); + argument_pointers.push_back(&argument); } return Execute( /*module=*/std::move(module), @@ -204,7 +204,7 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( /*profile=*/profile); } -StatusOr>> HloRunner::ExecuteReplicated( +StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) { TF_ASSIGN_OR_RETURN( @@ -290,9 +290,9 @@ StatusOr>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = absl::make_unique(); + Literal literal; TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - executor, options.outfeed_shape, literal.get())); + executor, options.outfeed_shape, &literal)); if (options.outfeed_values != nullptr) { options.outfeed_values->push_back(std::move(literal)); } @@ -310,10 +310,10 @@ StatusOr>> HloRunner::ExecuteReplicated( argument_buffer_slices)); LOG(INFO) << "Replicated execution terminated"; - std::vector> exec_results; + std::vector exec_results; for (int64 i = 0; i < options.num_replicas; ++i) { TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone()); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, backend().transfer_manager()->TransferLiteralFromDevice( streams[i].get(), results[i])); exec_results.push_back(std::move(literal)); diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 76d8b92bed484381a59d7f54e0a75bb7e75649ee..2e934bf66ae43ea412f242030b874dddb6d3722d 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -72,7 +72,7 @@ class HloRunner { // A pointer to a vector where the outfeed values will be stored. 
If // nullptr, the values will be read and discarded. - std::vector>* outfeed_values = nullptr; + std::vector* outfeed_values = nullptr; // Whether the HLO passes should be run on the input module. Usually // saved modules are coming from after the HLO pass pipeline, so triggering @@ -106,24 +106,23 @@ class HloRunner { StatusOr> TransferLiteralsToDevice( const absl::Span literals); StatusOr> TransferLiteralsToDevice( - const absl::Span> literals); - StatusOr> TransferLiteralFromDevice( - const ShapedBuffer& buffer); + const absl::Span literals); + StatusOr TransferLiteralFromDevice(const ShapedBuffer& buffer); // Executes the given module with given literals as input and returns the // result as a Literal. // // If run_hlo_passes is false, the module will be executed without Hlo // optimization. - StatusOr> Execute( - std::unique_ptr module, - const absl::Span arguments, - bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); - StatusOr> Execute( - std::unique_ptr module, - const absl::Span> arguments, - bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + StatusOr Execute(std::unique_ptr module, + const absl::Span arguments, + bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr); // As Execute(), but accepts and returns device buffers instead of host // buffers. @@ -140,7 +139,7 @@ class HloRunner { // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as // value. 
- StatusOr>> ExecuteReplicated( + StatusOr> ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options); diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc new file mode 100644 index 0000000000000000000000000000000000000000..3fc5dbeb02a26134a7f255fa0b6ebda1dc41ce4d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -0,0 +1,343 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_schedule.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace xla { + +/* static */ StatusOr HloSchedule::CreateFromProto( + const HloModule* module, const HloScheduleProto& proto) { + tensorflow::gtl::FlatMap id_to_computation; + for (const HloComputation* computation : module->computations()) { + id_to_computation[computation->unique_id()] = computation; + } + + HloSchedule schedule(module); + for (const auto& id_sequence : proto.sequences()) { + int64 computation_id = id_sequence.first; + + auto comp_it = id_to_computation.find(computation_id); + TF_RET_CHECK(comp_it != id_to_computation.end()) + << "No computation exists in HLO module with id " << computation_id; + const HloComputation* computation = comp_it->second; + + tensorflow::gtl::FlatMap id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + id_to_instruction[instruction->unique_id()] = instruction; + } + + HloInstructionSequence& sequence = + schedule.GetOrCreateSequence(computation); + for (const int64 instruction_id : id_sequence.second.instruction_ids()) { + auto instr_it = id_to_instruction.find(instruction_id); + TF_RET_CHECK(instr_it != id_to_instruction.end()) + << "No instruction exists in HLO computation " << computation->name() + << " with id " << instruction_id; + sequence.push_back(instr_it->second); + } + } + TF_RETURN_IF_ERROR(schedule.Verify()); + return std::move(schedule); +} + +StatusOr HloSchedule::ToProto() const { + TF_RETURN_IF_ERROR(Verify()); + HloScheduleProto proto; + for (const auto& id_sequence : 
sequences_) { + int64 computation_id = id_sequence.first; + const HloInstructionSequence& sequence = id_sequence.second; + HloScheduleProto::InstructionSequence& proto_sequence = + (*proto.mutable_sequences())[computation_id]; + proto_sequence.mutable_instruction_ids()->Reserve(sequence.size()); + for (const int64 id : sequence.ids()) { + proto_sequence.add_instruction_ids(id); + } + } + return std::move(proto); +} + +void HloSchedule::set_sequence( + const HloComputation* computation, + absl::Span sequence) { + set_sequence(computation, HloInstructionSequence(sequence)); +} + +void HloSchedule::set_sequence(const HloComputation* computation, + HloInstructionSequence sequence) { + CHECK(computation->parent() == module_); + sequences_[computation->unique_id()] = std::move(sequence); +} + +HloInstructionSequence& HloSchedule::GetOrCreateSequence( + const HloComputation* computation) { + auto it = sequences_.find(computation->unique_id()); + if (it == sequences_.end()) { + // No sequence found for computation. Create and return an empty one. + CHECK(computation->parent() == module_); + return sequences_[computation->unique_id()]; + } else { + return it->second; + } +} + +const HloInstructionSequence& HloSchedule::sequence( + const HloComputation* computation) const { + return sequences_.at(computation->unique_id()); +} + +Status HloSchedule::UpdateComputationSchedule( + const HloComputation* computation) { + // Map from unique ID to HloInstruction pointer for instructions in the + // computation. + tensorflow::gtl::FlatMap id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); + } + + // Set of all HloInstructions in the schedule. 
+ tensorflow::gtl::FlatSet ids_in_schedule; + for (int id : sequences_.at(computation->unique_id()).ids()) { + InsertOrDie(&ids_in_schedule, id); + } + + // Map from HloInstruction X to newly added instructions (instruction is in + // computation, but not in schedule) which use X. If an instruction is not in + // the map, then it has no users which are newly added instructions. + tensorflow::gtl::FlatMap> + new_instruction_uses; + + // For each newly added instruction, this is the count of the instruction's + // operands that have not yet been scheduled. When this value reaches zero, + // then the instruction may be placed in the schedule. + tensorflow::gtl::FlatMap + unscheduled_operand_count; + + // Create a worklist of newly added instructions which are ready to be added + // to the schedule. Initialize worklist with those that have zero operands. + std::queue worklist; + + for (const HloInstruction* instruction : computation->instructions()) { + if (ids_in_schedule.count(instruction->unique_id()) == 0) { + // This is a newly added instruction which is not in the schedule. + if (instruction->operands().empty()) { + worklist.push(instruction); + } else { + for (const HloInstruction* operand : instruction->operands()) { + new_instruction_uses[operand].push_back(instruction); + } + unscheduled_operand_count[instruction] = instruction->operand_count(); + } + } + } + + // Update the schedule with the newly added instructions, and remove any + // instructions no longer in the graph. + HloInstructionSequence new_sequence; + + // Lambda which schedules all instructions on the worklist. + auto schedule_worklist = [&]() { + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop(); + new_sequence.push_back(instruction); + std::vector* new_users = + tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); + if (new_users != nullptr) { + // This just-scheduled instruction has users which are newly added to + // the module. 
Update the number of unscheduled operands and push the + // newly added instruction to the worklist if it is ready to + // schedule. + for (const HloInstruction* new_user : *new_users) { + unscheduled_operand_count.at(new_user)--; + CHECK_GE(unscheduled_operand_count.at(new_user), 0); + if (unscheduled_operand_count.at(new_user) == 0) { + worklist.push(new_user); + } + } + } + } + }; + + schedule_worklist(); + for (int id : sequences_.at(computation->unique_id()).ids()) { + auto it = id_to_instruction.find(id); + if (it == id_to_instruction.end()) { + // This instruction in the schedule is no longer in the module. Do not add + // it to the new schedule. + continue; + } + worklist.push(it->second); + schedule_worklist(); + } + + set_sequence(computation, std::move(new_sequence)); + return Status::OK(); +} + +Status HloSchedule::Update() { + // The schedule must contain a sequence for every non-fusion computation in + // the module, but can have sequences for computations which no longer exist + // (these are removed). + std::vector nonfusion_computations = + module_->MakeNonfusionComputations(); + for (const HloComputation* computation : nonfusion_computations) { + TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + << "Computation " << computation->name() << " not in HloSchedule."; + } + if (sequences_.size() > nonfusion_computations.size()) { + // Schedule contains some computations which have been removed from the + // HloModule. Remove them from the schedule as well. 
+ tensorflow::gtl::FlatSet nonfusion_computations_ids; + for (const HloComputation* computation : nonfusion_computations) { + nonfusion_computations_ids.insert(computation->unique_id()); + } + for (auto it = sequences_.begin(); it != sequences_.end();) { + if (nonfusion_computations_ids.count(it->first) == 0) { + it = sequences_.erase(it); + } else { + it++; + } + } + } + CHECK_EQ(sequences_.size(), nonfusion_computations.size()); + + for (const HloComputation* computation : nonfusion_computations) { + TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation)); + } + + TF_RETURN_IF_ERROR(Verify()); + return Status::OK(); +} + +Status HloSchedule::Verify() const { + VLOG(2) << "VerifySchedule()"; + XLA_VLOG_LINES(3, module_->ToString()); + XLA_VLOG_LINES(2, ToString()); + + // Verify schedule contains exactly the same set of non-fusion computations as + // module currently does. + std::vector nonfusion_computations = + module_->MakeNonfusionComputations(); + TF_RET_CHECK(nonfusion_computations.size() == sequences_.size()) + << "Schedule has " << sequences_.size() << " sequences, but module has " + << nonfusion_computations.size() << " non-fusion computations"; + for (const HloComputation* computation : nonfusion_computations) { + TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + << "Computation " << computation->name() + << " missing from HLO schedule."; + } + + // For each computation verify the set of instructions is the same and that + // each dependency and control edge is honored. 
+ for (const HloComputation* computation : nonfusion_computations) { + tensorflow::gtl::FlatMap instruction_position; + int pos = 0; + for (const HloInstruction* instruction : + sequence(computation).instructions()) { + TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) + << "Instruction " << instruction->name() + << " appears more than once in the schedule"; + pos++; + } + + TF_RET_CHECK(instruction_position.size() == + computation->instruction_count()); + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(instruction_position.count(instruction) == 1) + << "Instruction " << instruction->name() << " is not in schedule"; + } + + for (const HloInstruction* instruction : computation->instructions()) { + for (const HloInstruction* operand : instruction->operands()) { + TF_RET_CHECK(instruction_position.at(operand) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its operand " << operand->name(); + } + + for (const HloInstruction* pred : instruction->control_predecessors()) { + TF_RET_CHECK(instruction_position.at(pred) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its control predecessor " + << pred->name(); + } + } + } + + return Status::OK(); +} + +namespace { + +// Returns the computation in the given module with the given unique ID. Returns +// nullptr if no such computation exists. 
+const HloComputation* IdToComputation(const HloModule* module, int64 id) { + for (const HloComputation* computation : module->computations()) { + if (computation->unique_id() == id) { + return computation; + } + } + return nullptr; +} + +} // namespace + +string HloSchedule::ToString() const { + std::vector pieces; + + pieces.push_back("HloSchedule"); + for (const auto& id_sequence : sequences_) { + const HloComputation* computation = + IdToComputation(module_, id_sequence.first); + if (computation == nullptr) { + // The computation is not in the module and may have been deleted so it is + // not safe to dereference any HLO pointers. Just use the HLO unique ids + // stored in this object. + pieces.push_back( + absl::StrFormat("computation with id %d (no longer in HLO module):", + id_sequence.first)); + for (int id : id_sequence.second.ids()) { + pieces.push_back(absl::StrCat(" ", id)); + } + } else { + pieces.push_back(absl::StrFormat("computation %s:", computation->name())); + for (const HloInstruction* instruction : + id_sequence.second.instructions()) { + pieces.push_back(absl::StrCat(" ", instruction->name())); + } + } + } + return absl::StrJoin(pieces, "\n"); +} + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) { + out << schedule.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h new file mode 100644 index 0000000000000000000000000000000000000000..270fe6039f0afd119c76086de9a0596e0560e93e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { + +class HloModule; + +// Class representing a sequence of HLO instructions such as the sequential +// execution order of an HLO computation. +class HloInstructionSequence { + public: + HloInstructionSequence() = default; + explicit HloInstructionSequence( + absl::Span instructions) { + for (const HloInstruction* instruction : instructions) { + push_back(instruction); + } + } + + // Adds the instruction to the end of the sequence. + void push_back(const HloInstruction* instruction) { + instruction_sequence_.push_back(instruction); + id_sequence_.push_back(instruction->unique_id()); + } + + // Clears the sequence of all instructions. + void clear() { + instruction_sequence_.clear(); + id_sequence_.clear(); + } + + int64 size() const { return instruction_sequence_.size(); } + + // Returns the sequence of HLO instructions. + const std::vector& instructions() const { + return instruction_sequence_; + } + + // Returns the unique IDs of the instructions in the sequence (in order). + const std::vector& ids() const { return id_sequence_; } + + private: + // The sequence as HloInstructions. 
+ std::vector instruction_sequence_; + + // The sequence of HLO instructions, represented by their unique IDs. The + // sequence is stored as both HloInstructions and unique IDs because the + // sequence may be referenced after transformations to the HLO graph and HLO + // pointers can be invalidated or recycled in this process (see + // HloSchedule::Update). + std::vector id_sequence_; +}; + +// A class representing a sequential schedule of instructions for an HLO +// module. A complete HLO schedule contains an instruction sequence for every +// non-fusion computation in the HLO module. +class HloSchedule { + public: + explicit HloSchedule(const HloModule* module) : module_(module) {} + + // (De)Serialize an HloSchedule to/from a HloScheduleProto. + static StatusOr CreateFromProto(const HloModule* module, + const HloScheduleProto& proto); + StatusOr ToProto() const; + + // Returns a reference to the sequence for the given computation. + const HloInstructionSequence& sequence( + const HloComputation* computation) const; + + // Returns the sequence for the given computation. An empty sequence is + // created if none exists for the computation. + HloInstructionSequence& GetOrCreateSequence( + const HloComputation* computation); + + // Sets the sequence for the given computation to the given sequence. + void set_sequence(const HloComputation* computation, + absl::Span sequence); + void set_sequence(const HloComputation* computation, + HloInstructionSequence sequence); + + // Returns a map from HloComputation unique ID to instruction sequence. The + // map contains all sequences in the schedule. + const tensorflow::gtl::FlatMap& sequences() + const { + return sequences_; + } + + // Returns true if the schedule has a sequence for the given computation. 
+ bool is_computation_scheduled(const HloComputation* computation) const { + return sequences_.count(computation->unique_id()) == 1; + } + + // Updates the schedule such that it is (again) a valid schedule for the + // module. This is used to update a schedule after the HLO module has been + // transformed in some way. In general, the only transformations to the module + // for which a schedule can be updated is the addition or removal of + // instructions and removal of computations. Updating the schedule after new + // dependencies between existing instructions in the module is not supported + // and may result in an error status returned. + // + // Instructions in the module which also exist in the given schedule will + // remain in the same order in the updated schedule. Instructions which exist + // in the module but not in the given schedule will be placed as early as + // possible in the updated schedule. + Status Update(); + + // Verifies that the given schedule is valid for the given module. + // Specifically, the schedule contains exactly the instructions in the + // non-fusion computations in the module and every dependency in the module is + // satisfied in the schedule. + Status Verify() const; + + string ToString() const; + + bool empty() const { return sequences_.empty(); } + + const HloModule* module() const { return module_; } + + private: + // Updates the instruction sequence for the given computation. + Status UpdateComputationSchedule(const HloComputation* computation); + + const HloModule* module_; + + // A map from computation unique ID to instruction sequence. Unique IDs are + // used rather than HloComputation pointers because HLO pointers are not + // unique across HLO transformations because pointers may be recycled. 
+ tensorflow::gtl::FlatMap sequences_; +}; + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1424569ac1f62e4b965876141f1eb40be4f15bea --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -0,0 +1,341 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_schedule.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloScheduleTest : public HloTestBase {}; + +TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) { + // Updating the schedule of an unchanged HLO module should not affect the + // schedule at all. 
+ const string module_str = R"( +HloModule UpdateScheduleUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + const std::vector& entry_schedule = + schedule.sequence(module->entry_computation()).instructions(); + + EXPECT_EQ(entry_schedule.size(), 6); + + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(entry_schedule, + schedule.sequence(module->entry_computation()).instructions()); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) { + // Add some additional instructions to a module and verify the schedule can be + // updated. + const string module_str = R"( +HloModule UpdateScheduleWithNewInstructions + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + HloComputation* entry = module->entry_computation(); + const Shape shape = entry->root_instruction()->shape(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, constant, entry->root_instruction())); + entry->set_root_instruction(sub); + + auto in_schedule = [&](const HloInstruction* hlo) { + return 
absl::c_linear_search(schedule.sequence(entry).instructions(), hlo); + }; + + EXPECT_EQ(schedule.sequence(entry).size(), 6); + EXPECT_FALSE(in_schedule(constant)); + EXPECT_FALSE(in_schedule(sub)); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 8); + EXPECT_TRUE(in_schedule(constant)); + EXPECT_TRUE(in_schedule(sub)); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) { + // Add and delete some instructions from a module and verify that the schedule + // can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithAddedAndDeletedInstruction + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + // Set the entry root to some expression containing just a parameter and a + // constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* new_root = entry->AddInstruction( + HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, + constant, entry->parameter_instruction(0))); + entry->set_root_instruction(new_root); + + // DCE should remove everything but the parameters and the newly added code. 
+ HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(entry).size(), 6); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 4); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) { + // Completely replace a module with an entirely new set of instructions and + // verify that the schedule can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithCompletelyReplacedModule + +ENTRY main { + a = f32[] constant(42.0) + b = f32[] constant(123.0) + ROOT sum = f32[] add(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + // Replace the entry computation with the negation of a constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNegate, constant)); + entry->set_root_instruction(new_root); + + // DCE the old instructions. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(entry).size(), 3); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 2); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) { + // Create changes to more than one computation in an HLO module and verify + // that the schedule can be updated. 
+ const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + const HloInstruction* xla_while = + module->entry_computation()->root_instruction()->operand(0); + HloComputation* body = xla_while->while_body(); + HloComputation* cond = xla_while->while_condition(); + + // Negate the root of the cond. + cond->set_root_instruction(cond->AddInstruction( + HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kNot, cond->root_instruction()))); + + // Replace the body with a computation which just passes through its + // parameter. 
+ body->set_root_instruction(body->parameter_instruction(0)); + + // DCE the dead code in the body. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(body).size(), 7); + EXPECT_EQ(schedule.sequence(cond).size(), 4); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(body).size(), 1); + EXPECT_EQ(schedule.sequence(cond).size(), 5); +} + +TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) { + // Remove computations from a module and verify the schedule can be updated. + const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, 
[](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + HloInstruction* xla_while = + module->entry_computation()->root_instruction()->mutable_operand(0); + HloInstruction* init = xla_while->mutable_operand(0); + + // Replace the while with its init value. The conditional and body + // computations should then be dead. + TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init)); + + // DCE the dead code in the body. + HloDCE dce; + ASSERT_EQ(module->computation_count(), 3); + TF_ASSERT_OK(dce.Run(module.get()).status()); + ASSERT_EQ(module->computation_count(), 1); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 34cba6136ff3fe95529f3bcf594db7776c8bfd0a..e3f4a9852ace86c20610362aa6ad3c3d9c78de30 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -422,6 +422,13 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const { : false; } +size_t ShardingMetadata::Hash() const { + if (sharding_ != nullptr) { + return sharding_->Hash(); + } + return static_cast(0x297814aaad196e6dULL); +} + string ShardingMetadata::ToString() const { return sharding_ != nullptr ? 
sharding_->ToString() : "{}"; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index cba5db927a056c760e1c4a291d96cfdbca818029..e3ae82a070643895f2ecac0e64073a88b592f7c1 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -36,6 +36,8 @@ class ShardingMetadata : public DomainMetadata { bool Matches(const DomainMetadata& other) const override; + size_t Hash() const override; + string ToString() const override; const HloSharding* sharding() const { return sharding_.get(); } diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 1e2b31a1f2bb4865faafc3d14e2b194e3aa171a1..6fd734a2b9e6c8c9fca76a944ca3df4c3b8a212f 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -24,7 +24,7 @@ namespace { using ::tensorflow::GraphDef; -class HloTfGraphBuilderTest : public HloTestBase { +class HloTfGraphBuilderTest : public HloVerifiedTestBase { protected: HloTfGraphBuilderTest() {} HloTfGraphBuilder generator_; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 95516dec74bd253212901a3d9a92285d11fe122f..50f39cbcb55e29a2654ed8c745ea24ee2e0ab899 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -86,8 +86,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { const Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->window(), convolution->convolution_dimension_numbers(), - convolution->feature_group_count())); + convolution->feature_group_count(), convolution->window(), + convolution->convolution_dimension_numbers())); return CheckShape(convolution, expected); } @@ -1123,6 +1123,11 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + // If the module has a schedule, it must be valid. 
+ if (module->has_schedule()) { + TF_RETURN_IF_ERROR(module->schedule().Verify()); + } + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 0cac210c2413e979300e191cb54860bcd0ab79b5..8f0423bb1c72ceb209437116a898d027f4d2c657 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -290,8 +290,8 @@ TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { padding_config.add_dimensions()->set_interior_padding(-1); builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {100}), param, - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(F32).CloneToUnique())), + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32))), padding_config)); auto module = CreateNewModule(); @@ -314,8 +314,8 @@ TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { padding_config.add_dimensions()->set_interior_padding(-1); builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {100}), param, - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(F32).CloneToUnique())), + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())), padding_config)); auto module = CreateNewModule(); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index a4de02a89039e07b22b1ad8c268c2f760aa95880..06f0e1ed25e71659a61e6de8a84e52cf70064eae 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -165,6 +165,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayFor( TF_ASSIGN_OR_RETURN( computed_array, ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(), + instr->precision_config(), FindOrDie(cache_, instr->operand(0)), FindOrDie(cache_, 
instr->operand(1)))); } else { @@ -917,7 +918,7 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, // inner_broadcast_result is the Broadcast'(Const0) bit in // BinaryOp(Broadcast'(Const0), Const1) TF_ASSIGN_OR_RETURN( - std::unique_ptr inner_broadcast_result, + Literal inner_broadcast_result, broadcast_const_operand->literal().Broadcast( scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); @@ -927,12 +928,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, TF_ASSIGN_OR_RETURN( literal_for_new_source, TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( - opcode, scalar_indexed_const->literal(), *inner_broadcast_result))); + opcode, scalar_indexed_const->literal(), inner_broadcast_result))); } else { TF_ASSIGN_OR_RETURN( literal_for_new_source, TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( - opcode, *inner_broadcast_result, scalar_indexed_const->literal()))); + opcode, inner_broadcast_result, scalar_indexed_const->literal()))); } ConstantArray* new_source = Construct(literal_for_new_source); @@ -1030,7 +1031,8 @@ bool CanFoldDotIntoIndexedArray( StatusOr IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ScalarIndexedConstantArray* lhs, ConstantArray* rhs) { + const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, + ConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " " << ToString(rhs); if (!CanFoldDotIntoIndexedArray( @@ -1045,9 +1047,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( new_dim_numbers.set_lhs_contracting_dimensions( 0, lhs->source_dim() == (lhs_rank - 1) ? 
(lhs_rank - 2) : (lhs_rank - 1)); - TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, - TakeOwnership(HloEvaluator{}.EvaluateDotOp( - new_dim_numbers, lhs->literal(), *rhs->literal()))); + TF_ASSIGN_OR_RETURN( + Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, precision_config, lhs->literal(), *rhs->literal()))); // The new source dimension is wherever the non-batch non-contracting LHS // dimension "went". @@ -1063,7 +1066,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( StatusOr IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ConstantArray* lhs, ScalarIndexedConstantArray* rhs) { + const PrecisionConfig& precision_config, ConstantArray* lhs, + ScalarIndexedConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " " << ToString(rhs); if (!CanFoldDotIntoIndexedArray( @@ -1079,9 +1083,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( new_dim_numbers.set_rhs_contracting_dimensions( 0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1)); - TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, - TakeOwnership(HloEvaluator{}.EvaluateDotOp( - new_dim_numbers, *lhs->literal(), rhs->literal()))); + TF_ASSIGN_OR_RETURN( + Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, precision_config, *lhs->literal(), rhs->literal()))); // The new source dimension is wherever the non-batch non-contracting RHS // dimension "went". 
@@ -1095,8 +1100,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( } StatusOr IndexedArrayAnalysis::ComputeArrayForDot( - const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs, - Array* rhs) { + const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, Array* lhs, Array* rhs) { // Intuitively, if // // - The LHS of a dot product is a gathered sequence of rows from a constant @@ -1119,6 +1124,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForDot( dynamic_cast(lhs)) { if (auto* rhs_constant = dynamic_cast(rhs)) { return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers, + precision_config, lhs_indexed_array, rhs_constant); } } @@ -1126,7 +1132,8 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForDot( if (auto* rhs_indexed_array = dynamic_cast(rhs)) { if (auto* lhs_constant = dynamic_cast(lhs)) { - return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant, + return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, + precision_config, lhs_constant, rhs_indexed_array); } } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index dcfb7255358ae08660fe2c6eae5af9f10370e762..df9cbab915cc037cec682238886fb524eaeb2c90 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -267,14 +267,17 @@ class IndexedArrayAnalysis { StatusOr ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ScalarIndexedConstantArray* lhs, ConstantArray* rhs); + const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, + ConstantArray* rhs); StatusOr ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ConstantArray* lhs, ScalarIndexedConstantArray* rhs); + const PrecisionConfig& precision_config, ConstantArray* lhs, + ScalarIndexedConstantArray* rhs); 
StatusOr ComputeArrayForDot(const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, Array* lhs, Array* rhs); // This tries to fold a ScalarIndexedArray which has another @@ -344,21 +347,19 @@ class IndexedArrayAnalysis { } } - Literal* TakeOwnership(std::unique_ptr literal) { + Literal* TakeOwnership(Literal literal) { owned_literals_.push_back(std::move(literal)); - return owned_literals_.back().get(); + return &owned_literals_.back(); } - StatusOr TakeOwnership( - StatusOr> literal_or_error) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - std::move(literal_or_error)); + StatusOr TakeOwnership(StatusOr literal_or_error) { + TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error)); owned_literals_.push_back(std::move(literal)); - return owned_literals_.back().get(); + return &owned_literals_.back(); } std::vector> owned_tensors_; - std::vector> owned_literals_; + std::vector owned_literals_; tensorflow::gtl::FlatMap cache_; }; diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 5695bc242057c037a1999e7d63f5b4f21b5f658a..7e967f035c1054e22d10790188a5a232ca8e751a 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -26,7 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using InlinerTest = HloTestBase; +using InlinerTest = HloVerifiedTestBase; // Test that `map` with `max` is transformed to `max` TEST_F(InlinerTest, MapMax) { @@ -64,14 +64,14 @@ TEST_F(InlinerTest, MapMax) { hlo_module->AddEntryComputation(std::move(computation)); Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Maximum(lhs, rhs)); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } // Test that `constant` function is changed to `broadcast`. @@ -98,14 +98,14 @@ TEST_F(InlinerTest, MapConstant) { hlo_module->AddEntryComputation(std::move(computation)); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); // Verify execution on CPU. 
- auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } TEST_F(InlinerTest, MapSubtractOppositeOrder) { @@ -136,14 +136,14 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { hlo_module->AddEntryComputation(std::move(computation)); Inliner inliner; - EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Subtract(rhs, lhs)); // Verify execution on CPU. - auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); auto expected = LiteralUtil::CreateR1({3, 1, -1, -3}); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 8c907eae0cbe7c3764a2bfe8fed6b6098931de38..3fdc2cee9aad0fe70f66920f757ee5c52bba711f 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" @@ -295,6 +296,138 @@ InstructionFusion::ComputeGloballyUnfusible( return do_not_duplicate; } +namespace { + +// A FusionQueue that uses reverse post order. +// +// We want to be able to remove arbitrary instructions from the post order and +// also compare positions of instructions in the post order. 
To make this +// possible, create vector of instructions in post order and create a map from +// HloInstruction* to the instruction's index in the vector. An instruction is +// "removed" from the vector by setting it's element to nullptr. +class ReversePostOrderFusionQueue : public FusionQueue { + public: + explicit ReversePostOrderFusionQueue(HloComputation* computation) { + post_order_ = computation->MakeInstructionPostOrder(); + + for (size_t i = 0; i < post_order_.size(); ++i) { + InsertOrDie(&post_order_index_, post_order_[i], i); + } + } + + std::pair> + DequeueNextInstructionAndOperandsToFuseInOrder() override { + // Instructions are "removed" from the post order by nulling out the element + // in the vector, so if the pointer is null, continue to the next + // instruction in the sort. + while (!post_order_.empty() && post_order_.back() == nullptr) { + post_order_.pop_back(); + } + if (post_order_.empty()) { + return std::pair>{nullptr, {}}; + } + // We want to iterate in reverse post order, so remove from the back of the + // vector. + HloInstruction* instruction = post_order_.back(); + post_order_.pop_back(); + + CHECK(instruction != nullptr); + // Remove instruction from the index map to ensure the vector and map stay + // consistent. + post_order_index_.erase(instruction); + + // Consider each operand of this instruction for fusion into this + // instruction. We want to consider the operands in a particular order to + // avoid creating duplicate instruction clones in the fusion instruction. + // For example, consider the following expression: + // + // A = ... + // B = op(A) + // C = op(A, B) + // + // If we are considering the operands of C for fusion into C. We might + // fuse A or B first. If we fuse A first, we get: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // C' = op(A', B) } + // + // Where A' and C' are clones of A and C, respectively. 
Now only B is an + // operand of the fusion instruction C_fusion, so then we fuse B: + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // B' = op(A) + // C' = op(A', B') } + // + // Now A is an operand of C_fusion again, so we then fuse A (again!): + // + // A = ... + // B = op(A) + // C_fusion = { A' = ... + // A" = .. + // B' = op(A") + // C' = op(A', B') } + // + // We prevent this duplication by considering the operands in the order + // they appear int the queue. In the example, this ensures that B will be + // considered before A. + // + // We store the original indices of the operands to pass to ShouldFuse. + std::vector sorted_operand_numbers; + sorted_operand_numbers.reserve(instruction->operands().size()); + for (int i = 0; i < instruction->operands().size(); ++i) { + // This will happen if we have two possible instructions to fuse the + // same operand into; once the operand is fused into one instruction, + // the other instruction will get a new get-tuple-element as its + // operand, which is not in the queue. + // TODO(tjoerg): Look into fusing past these multi-output fuse points. + if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) { + continue; + } + sorted_operand_numbers.push_back(i); + } + std::sort( + sorted_operand_numbers.begin(), sorted_operand_numbers.end(), + [&](int64 i, int64 j) { + // Instructions with higher priority in the queue come first. + return ( + FindOrDie(post_order_index_, instruction->mutable_operand(i)) > + FindOrDie(post_order_index_, instruction->mutable_operand(j))); + }); + return std::make_pair(instruction, sorted_operand_numbers); + } + + void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) override { + // Fusing an instruction into a fusion instruction can change the operand + // set of the fusion instruction. For simplicity just re-enqueue the + // instruction and reconsider it for further fusion in the next iteration. 
+ InsertOrDie(&post_order_index_, fusion, post_order_.size()); + post_order_.push_back(fusion); + } + + void RemoveInstruction(HloInstruction* instruction) override { + post_order_[FindOrDie(post_order_index_, instruction)] = nullptr; + post_order_index_.erase(instruction); + } + + private: + std::vector post_order_; + tensorflow::gtl::FlatMap post_order_index_; +}; + +} // namespace + +std::unique_ptr InstructionFusion::GetFusionQueue( + HloComputation* computation, + const std::function& skip_producer) { + return absl::make_unique(computation); +} + StatusOr InstructionFusion::Run(HloModule* module) { VLOG(2) << "Before instruction fusion:"; XLA_VLOG_LINES(2, module->ToString()); @@ -306,111 +439,31 @@ StatusOr InstructionFusion::Run(HloModule* module) { computation_ = computation; reachability_ = computation_->ComputeReachability(); - // We want to be able to remove arbitrary instructions from the post order - // and also compare positions of instructions in the post order. To make - // this possible, create vector of instructions in post order and create a - // map from HloInstruction* to the instruction's index in the vector. An - // instruction is "removed" from the vector by setting it's element to - // nullptr. - std::vector post_order = - computation_->MakeInstructionPostOrder(); - - tensorflow::gtl::FlatMap post_order_index; - for (size_t i = 0; i < post_order.size(); ++i) { - InsertOrDie(&post_order_index, post_order[i], i); - } - - HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order); + HloInstructionSet do_not_duplicate = + ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder()); + auto fusion_queue = + GetFusionQueue(computation_, [&](HloInstruction* producer) { + return do_not_duplicate.count(producer) > 0; + }); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all // edges. 
When we fuse an edge, we create a copy of the producer inside the // fusion instruction. - while (!post_order.empty()) { - // We want to iterate in reverse post order, so remove from the back of - // the vector. - HloInstruction* instruction = post_order.back(); - post_order.pop_back(); - - // Instructions are "removed" from the post order by nulling out the - // element in the vector, so if the pointer is null, continue to the next - // instruction in the sort. + while (true) { + auto next_entry = + fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder(); + auto instruction = next_entry.first; if (instruction == nullptr) { - continue; + break; } - // Remove instruction from the index map to ensure the vector and map stay - // consistent. - post_order_index.erase(instruction); - if (!instruction->IsFusible() && instruction->opcode() != HloOpcode::kFusion) { continue; } - // Consider each operand of this instruction for fusion into this - // instruction. We want to consider the operands in a particular order to - // avoid creating duplicate instruction clones in the fusion instruction. - // For example, consider the following expression: - // - // A = ... - // B = op(A) - // C = op(A, B) - // - // If we are considering the operands of C for fusion into C. We might - // fuse A or B first. If we fuse A first, we get: - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // C' = op(A', B) } - // - // Where A' and C' are clones of A and C, respectively. Now only B is an - // operand of the fusion instruction C_fusion, so then we fuse B: - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // B' = op(A) - // C' = op(A', B') } - // - // Now A is an operand of C_fusion again, so we then fuse A (again!): - // - // A = ... - // B = op(A) - // C_fusion = { A' = ... - // A" = .. - // B' = op(A") - // C' = op(A', B') } - // - // We prevent this duplication by considering the operands in the reverse - // order they appear in the instruction post order. 
In the example, this - // ensures that B will be considered before A. - // - // We store the original indices of the operands to pass to ShouldFuse. - std::vector sorted_operand_numbers; - sorted_operand_numbers.reserve(instruction->operands().size()); - for (int i = 0; i < instruction->operands().size(); ++i) { - // This will happen if we have two possible instructions to fuse the - // same operand into; once the operand is fused into one instruction, - // the other instruction will get a new get-tuple-element as its - // operand, which is not in the post-order index. - // TODO(tjoerg): Look into fusing past these multi-output fuse points. - if (post_order_index.find(instruction->mutable_operand(i)) == - post_order_index.end()) { - continue; - } - sorted_operand_numbers.push_back(i); - } - std::sort( - sorted_operand_numbers.begin(), sorted_operand_numbers.end(), - [&](int64 i, int64 j) { - // Instructions with higher indices in the post order come - // first. - return ( - FindOrDie(post_order_index, instruction->mutable_operand(i)) > - FindOrDie(post_order_index, instruction->mutable_operand(j))); - }); + std::vector& sorted_operand_numbers = next_entry.second; for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); @@ -425,32 +478,31 @@ StatusOr InstructionFusion::Run(HloModule* module) { // TODO(tjoerg): Consider making multi-output fusion the default. if (ShouldFuse(instruction, i) && do_not_duplicate.count(operand) == 0) { + fusion_queue->PreFusion(operand, instruction); fusion_instruction = Fuse(operand, instruction); } else if (ShouldFuseIntoMultiOutput(instruction, i) && !MultiOutputFusionCreatesCycle(operand, instruction)) { + fusion_queue->PreFusion(operand, instruction); fusion_instruction = FuseIntoMultiOutput(operand, instruction); } else { continue; } - // Fusing an instruction into a fusion instruction can change the - // operand set of the fusion instruction. 
For simplicity just push the - // instruction to the top of the post_order and reconsider it for - // further fusion in the next iteration of the outer loop. - post_order.push_back(fusion_instruction); - InsertOrDie(&post_order_index, fusion_instruction, - post_order.size() - 1); + fusion_queue->OnFusingInstruction(fusion_instruction, operand, + instruction); changed = true; if (operand->user_count() == 0) { - // Operand is now dead. Remove from post order by setting its - // location to nullptr. - post_order[FindOrDie(post_order_index, operand)] = nullptr; - post_order_index.erase(operand); - + do_not_duplicate.erase(operand); + // Operand is now dead. Remove from queue. + fusion_queue->RemoveInstruction(operand); // Remove from computation. TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand)); } + + if (fusion_instruction != instruction) { + do_not_duplicate.erase(instruction); + } break; } } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 00b658959a2cceeb30d2ec03f243119ec0a8ee47..c1fde8ecfc04792c6c17ebd83190486ef720175a 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -24,6 +24,33 @@ limitations under the License. namespace xla { +// A queue interface that allows implementations to choose fusion candidates in +// custom order. +class FusionQueue { + public: + FusionQueue() = default; + virtual ~FusionQueue() = default; + + // Dequeues the next fusion candidates: a consumer and the list of producers + // as operand indices. + virtual std::pair> + DequeueNextInstructionAndOperandsToFuseInOrder() = 0; + + // A callback passed to the queue implementation right before the producer is + // fused into the consumer. + virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} + + // A callback passed to the queue implementation right after the fusion is + // created. 
Note that original_producer could have been destroyed. + virtual void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) {} + + // A callback passed to the queue implementation to notify the removal of an + // instruction. + virtual void RemoveInstruction(HloInstruction* instruction) = 0; +}; + // HLO pass which performs instruction fusion. Instructions are fused // "vertically", meaning producing instructions are fused into their consumers // with the intent that the loops which compute their values will be fused in @@ -48,6 +75,13 @@ class InstructionFusion : public HloPassInterface { static bool IsExpensive(const HloInstruction& instruction); protected: + // Returns a FusionQueue that implements custom order of instructions being + // fused. The default implementation processes consumers in reverse post + // order. + virtual std::unique_ptr GetFusionQueue( + HloComputation* computation, + const std::function& skip_producer); + // Returns whether the given producer instruction should be fused into the // given consumer instruction. producer is necessarily an operand of consumer. // Derived classes should define this method to specify which instructions diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 5dea12476849db6f7a9a9214398b4e57262aeda0..a06d6113e84630df14ff68280c248cccb9afaf06 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -73,30 +73,29 @@ StatusOr InterpreterExecutable::ExecuteOnStream( // Transform the ShapedBuffer arguments into literals which the evaluator // consumes. 
- std::vector> arg_literals; + std::vector arg_literals; for (int64 p = 0; p < computation->num_parameters(); ++p) { - TF_ASSIGN_OR_RETURN(std::unique_ptr arg_literal, + TF_ASSIGN_OR_RETURN(Literal arg_literal, transfer_manager->TransferLiteralFromDevice( run_options->stream(), *arguments[p])); arg_literals.push_back(std::move(arg_literal)); } // Execute the graph using the HloEvaluator. - std::unique_ptr result_literal; + Literal result_literal; { tensorflow::mutex_lock lock(evaluator_lock_); - TF_ASSIGN_OR_RETURN(result_literal, - evaluator_->Evaluate>( - *computation, arg_literals)); + TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate( + *computation, arg_literals)); } // Transform the result literal back into a ShapedBuffer. TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, transfer_manager->AllocateScopedShapedBuffer( - result_literal->shape(), run_options->allocator(), + result_literal.shape(), run_options->allocator(), executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - run_options->stream(), *result_literal, result)); + run_options->stream(), result_literal, result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 6e17711f575b24ffcfcbf1a78bb803603b001adf..082bf8bffed484244139e79f4d3fe30ca091d8ac 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -855,8 +855,7 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction, ? instruction.sharding().GetSubSharding(instruction.shape(), index) : instruction.sharding(); // We propagate the sharding to the copied instruction only if it is a - // special sharding, like tiled ones, or special devices like the - // HostCompute module. + // special sharding, like tiled ones. 
// Otherwise it is preferable to leave the new instruction without device, // and let the automatic device placer to choose the best location. auto device = sharding.UniqueDevice(); diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 021fe630ff6329c51e297d0bb2bee8269a42904b..752a61476dd7892a2b7f531c4057015f48fc4758 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -35,7 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,7 +49,7 @@ namespace { using ::testing::ElementsAre; -class LayoutAssignmentTest : public HloTestBase { +class LayoutAssignmentTest : public HloVerifiedTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout, @@ -91,7 +91,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayout) { *computation_layout.mutable_parameter_layout(0) = shape_layout; *computation_layout.mutable_parameter_layout(1) = shape_layout; *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout())); @@ -127,7 +127,7 @@ TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) { *computation_layout.mutable_parameter_layout(1) = row_major; *computation_layout.mutable_result_layout() = 
col_major; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout())); EXPECT_TRUE(LayoutUtil::Equal( @@ -145,7 +145,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); auto constant_literal2 = LiteralUtil::CreateR2WithLayout( {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); - Shape ashape = constant_literal1->shape(); + Shape ashape = constant_literal1.shape(); auto constant1 = builder.AddInstruction( HloInstruction::CreateConstant(std::move(constant_literal1))); @@ -172,7 +172,7 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { ComputationLayout computation_layout(computation->ComputeProgramShape()); *computation_layout.mutable_result_layout() = shape_layout; - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::Equal( layout, fusion->fused_parameter(0)->shape().layout())); @@ -213,7 +213,7 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE( LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape())); @@ -243,7 +243,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( - tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -255,7 +255,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { 
TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape( result_shape)); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape())); } @@ -294,7 +294,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { result_shape)); LayoutAssignment layout_assignment(&computation_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); // Layout assignment should have deep copied the result of the computation to // address the layout conflict. This results in several Tuple() and @@ -310,7 +310,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { EXPECT_TRUE( AlgebraicSimplifier(/*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return false; }) - .Run(module.get()) + .Run(module) .ValueOrDie()); HloInstruction* root = module->entry_computation()->root_instruction(); // Verify layout of the root and the root's operands. 
@@ -352,7 +352,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); auto log_minor_to_major = AsInt64Slice(log->shape().layout().minor_to_major()); @@ -393,7 +393,7 @@ TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) { *computation_layout.mutable_parameter_layout(0) = ShapeLayout(ashape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE( LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout())); @@ -432,7 +432,7 @@ TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) { ShapeLayout(input_shape_with_layout); *computation_layout.mutable_result_layout() = ShapeLayout(output_shape_with_layout); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1, 2)); @@ -457,13 +457,13 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, f32_4, "param")); auto broadcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(f32_34, param, {3})); + HloInstruction::CreateBroadcast(f32_34, param, {1})); auto transpose = builder.AddInstruction( HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0})); auto tanh = builder.AddInstruction( HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast)); auto broadcast2 = builder.AddInstruction( - HloInstruction::CreateBroadcast(f32_234, tanh, {2})); + HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2})); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({transpose, broadcast2})); auto 
module = CreateNewModule(); @@ -485,7 +485,7 @@ TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) { *computation_layout.mutable_result_layout() = ShapeLayout(ShapeUtil::MakeTupleShape( {transpose_shape_with_layout, broadcast2_shape_with_layout})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1)); EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0)); @@ -551,7 +551,7 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { *computation_layout.mutable_parameter_layout(1) = ShapeLayout(param1_shape_with_layout); OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout); - EXPECT_IS_OK(layout_assignment.Run(module.get()).status()); + EXPECT_IS_OK(layout_assignment.Run(module).status()); EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode()); EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(), @@ -575,7 +575,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { HloComputation* computation = module->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -593,7 +593,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { HloComputation* computation = module->AddEntryComputation(builder.Build(transpose)); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } @@ -659,18 +659,18 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - 
auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); - module = + std::unique_ptr compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); EXPECT_EQ(Status::OK(), backend() .compiler() - ->RunBackend(std::move(module), + ->RunBackend(std::move(compiled_module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .status()); @@ -699,9 +699,9 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + module().entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), ShapeUtil::MakeTupleShape({ @@ -713,19 +713,19 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { param_shape)); computation_layout.mutable_result_layout()->ResetLayout( LayoutUtil::MakeLayout({2, 1, 0})); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(&module(), &computation_layout); - EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0)); - EXPECT_THAT(FindInstruction(module.get(), "gte1") + EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(&module(), "gte1") ->shape() .tuple_shapes(0) 
.layout() .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(FindInstruction(module.get(), "gte1") + EXPECT_THAT(FindInstruction(&module(), "gte1") ->shape() .tuple_shapes(1) .layout() @@ -785,7 +785,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { HloComputation* computation = module->AddEntryComputation(builder.Build()); ComputationLayout computation_layout(computation->ComputeProgramShape()); - AssignLayouts(module.get(), &computation_layout); + AssignLayouts(module, &computation_layout); const HloInstruction* true_root = true_computation->root_instruction(); const HloInstruction* false_root = false_computation->root_instruction(); @@ -812,7 +812,7 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); LayoutAssignment layout_assignment(&computation_layout); - Status error_status = layout_assignment.Run(module.get()).status(); + Status error_status = layout_assignment.Run(module).status(); EXPECT_FALSE(error_status.ok()); EXPECT_THAT( error_status.error_message(), @@ -839,9 +839,9 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); + ParseAndVerifyModule(module_str); ComputationLayout computation_layout( - module->entry_computation()->ComputeProgramShape()); + module().entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); TF_ASSERT_OK( @@ -851,14 +851,13 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { LayoutUtil::MakeLayout({1, 0})); ChannelLayoutConstraints channel_constraints; - AssignLayouts(module.get(), &computation_layout, &channel_constraints); + AssignLayouts(&module(), &computation_layout, &channel_constraints); - EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); - EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0)); - EXPECT_TRUE( - 
ShapeUtil::Equal(ShapeUtil::GetSubshape( - FindInstruction(module.get(), "send")->shape(), {0}), - ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); + EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0)); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { @@ -873,19 +872,19 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); - module = + ParseAndVerifyModule(module_str); + auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - - auto copy = FindInstruction(module.get(), "copy.1"); - auto slice = FindInstruction(module.get(), "slice0"); - EXPECT_EQ(slice->operand(0), copy); - EXPECT_TRUE( - LayoutUtil::Equal(slice->shape().layout(), copy->shape().layout())); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); + EXPECT_THAT(root, op::Add(op::Parameter(), + op::Slice(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy))))); } TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { @@ -901,19 +900,21 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); - module = + ParseAndVerifyModule(module_str); + auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), 
/*device_allocator=*/nullptr) .ConsumeValueOrDie(); - - auto copy = FindInstruction(module.get(), "copy.1"); - auto dslice = FindInstruction(module.get(), "dslice0"); - EXPECT_EQ(dslice->operand(0), copy); - EXPECT_TRUE( - LayoutUtil::Equal(dslice->shape().layout(), copy->shape().layout())); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); + EXPECT_THAT(root, + op::Add(op::Parameter(), + op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy)), + op::Parameter(2)))); } TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { @@ -930,19 +931,21 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { } )"; - auto module = ParseHloString(module_str).ValueOrDie(); - module = + ParseAndVerifyModule(module_str); + auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - - auto copy = FindInstruction(module.get(), "copy.1"); - auto concat = FindInstruction(module.get(), "concat0"); - EXPECT_EQ(concat->operand(0), copy); - EXPECT_TRUE( - LayoutUtil::Equal(concat->shape().layout(), copy->shape().layout())); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0}); + EXPECT_THAT(root, + op::Add(op::Parameter(), + op::Concatenate(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy)), + op::Parameter(2)))); } TEST_F(LayoutAssignmentTest, @@ -959,16 +962,40 @@ TEST_F(LayoutAssignmentTest, } )"; - auto module = ParseHloString(module_str).ValueOrDie(); - module = + ParseAndVerifyModule(module_str); + auto compiled_module = backend() .compiler() - ->RunHloPasses(std::move(module), 
backend().default_stream_executor(), + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1))); +} + +TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { + const char* module_str = R"( + HloModule PropagatingLayoutFromResultToOperand + + ENTRY PropagatingLayoutFromResultToOperand { + par0 = f32[4,5]{1,0} parameter(0) + ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]} + } + )"; - auto copy = FindInstruction(module.get(), "copy.1"); - EXPECT_EQ(copy, nullptr); + ParseAndVerifyModule(module_str); + auto compiled_module = + backend() + .compiler() + ->RunHloPasses(module().Clone(), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1}); + EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)), + op::ShapeWithLayout(shape_copy)))); } } // namespace diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 7d49b8d6c2c902ee38d72f72b3da9d190cc65bf0..a60643bc754f896d096b3ca4e1216e77d7e384c6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -75,6 +75,16 @@ void EmitTuple(const IrArray& tuple, absl::Span operands, } } +void EmitTuple(const IrArray& tuple, absl::Span buffers, + llvm::IRBuilder<>* b, llvm::Module* module) { + std::vector buffer_ptrs; + buffer_ptrs.reserve(buffers.size()); + absl::c_transform( + buffers, std::back_inserter(buffer_ptrs), + [](const llvm_ir::IrArray& buffer) { return buffer.GetBasePointer(); }); + llvm_ir::EmitTuple(tuple, buffer_ptrs, b, module); +} + 
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, int alignment, llvm::Value* operand, llvm::IRBuilder<>* b, llvm::Module* module) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index 887fb613717ef780d6903a3b97bfdf4b735c4f82..94340b91d8eeea1ba4681c2e49c0894eab2f6cc0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -68,6 +68,11 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred, void EmitTuple(const IrArray& tuple, absl::Span operands, llvm::IRBuilder<>* b, llvm::Module* module); +// Similar to EmitTuple above, except that the output buffers are provided in +// the form of IrArray. +void EmitTuple(const IrArray& tuple, absl::Span buffers, + llvm::IRBuilder<>* b, llvm::Module* module); + // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. A GetTupleElement instruction // forwards the pointer to underlying tuple element buffer at the given index. 
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index f0e2566a3f9ef5c0be8af46d3a16cd9c72793366..922ebdf0e3f0e79674c5a632c873627845a606ec 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -68,9 +68,9 @@ Status RecordArguments(const absl::Span arguments, module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, + Literal literal, transfer_manager->TransferLiteralFromDevice(stream, *argument)); - *module->add_arguments() = literal->ToProto(); + *module->add_arguments() = literal.ToProto(); } return Status::OK(); } @@ -80,9 +80,9 @@ Status RecordResult(const ShapedBuffer& result, se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, + Literal literal, transfer_manager->TransferLiteralFromDevice(stream, result)); - *module->mutable_result() = literal->ToProto(); + *module->mutable_result() = literal.ToProto(); return Status::OK(); } @@ -928,16 +928,15 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, shaped_buffer->device_ordinal())); TF_ASSIGN_OR_RETURN( - std::unique_ptr result_literal, + Literal result_literal, execute_backend_->transfer_manager()->TransferLiteralFromDevice( stream.get(), *shaped_buffer)); - if (LayoutUtil::LayoutsInShapesEqual(*return_shape, - result_literal->shape())) { - *result->mutable_literal() = result_literal->ToProto(); + if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) { + *result->mutable_literal() = result_literal.ToProto(); } else { *result->mutable_literal() = - result_literal->Relayout(*return_shape)->ToProto(); + result_literal.Relayout(*return_shape).ToProto(); } return Status::OK(); } @@ -959,9 +958,9 @@ std::unique_ptr CloneShapedBufferOnDevice( Status Service::TransferToServer(const TransferToServerRequest* arg, 
TransferToServerResponse* result) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(arg->literal())); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); std::vector replicas; if (arg->has_device_handle()) { @@ -983,7 +982,7 @@ Status Service::TransferToServer(const TransferToServerRequest* arg, TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - stream.get(), *literal, shaped_buffer)); + stream.get(), literal, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } TF_ASSIGN_OR_RETURN(*result->mutable_data(), @@ -1018,10 +1017,10 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor = replicas[arg->replica_id()]; } - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(arg->literal())); - return execute_backend_->transfer_manager()->TransferLiteralToInfeed( - executor, *literal); + return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor, + literal); } Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, @@ -1049,8 +1048,8 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( - executor, arg->shape_with_layout(), *literal)); - *result->mutable_literal() = literal->ToProto(); + executor, arg->shape_with_layout(), literal)); + *result->mutable_literal() = literal.ToProto(); return Status::OK(); } @@ -1085,18 +1084,17 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, HloModule::CreateFromProto(arg->computation(), config)); HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, - evaluator.Evaluate>( - *module, /*arg_literals=*/{})); + TF_ASSIGN_OR_RETURN(auto result_literal, 
evaluator.Evaluate( + *module, /*arg_literals=*/{})); // Since the result layout is non-effective to the Evaluator results, explicit // relayout here. // // TODO(b/77824332): Make HloEvaluator take care of the re-layout. if (arg->has_output_layout()) { - result_literal = result_literal->Relayout(arg->output_layout()); + result_literal = result_literal.Relayout(arg->output_layout()); } - *result->mutable_literal() = result_literal->ToProto(); + *result->mutable_literal() = result_literal.ToProto(); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 26117498621450d56259507761b6b0a6ea8d3a15..74bdf2a2e3982bc9be29bae037e385fede578ae5 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1552,8 +1552,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferConvolveShape( - const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) { + const Shape& lhs, const Shape& rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dnums) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); @@ -1672,6 +1672,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); } + if (kernel_output_features % feature_group_count > 0) { + return InvalidArgument( + "Expected output feature dimension (value %d) to be divisible by " + "feature_group_count (value %d); " + "got (%s, %s)\n" + "Dimension numbers: {%s}.", + kernel_output_features, feature_group_count, + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); + } std::vector window_dims(num_spatial_dims); for 
(int i = 0; i < num_spatial_dims; ++i) { window_dims[i] = window.dimensions(i).size(); diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index a28345acefb8fca1c8b6444f431f932c23c57ce4..96a0ee165d46753da4fef119e7072f66637bf2c4 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -108,9 +108,9 @@ class ShapeInference { // Infers the shape produced by applying the given convolutional // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr InferConvolveShape( - const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + const Shape& lhs, const Shape& rhs, int64 feature_group_count, + const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); // Infers the shape produced by the given FFT type on the given operand. static StatusOr InferFftShape(const Shape& in, FftType fft_type, diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index cc92e58ef867ee716714fff4fdab07b9cb836d00..864ed43118cd066f6ce14cd808b873f137b8414a 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -419,8 +419,8 @@ TEST_F(ShapeInferenceTest, Convolve) { dim1->set_padding_high(0); dim1->set_window_dilation(1); dim1->set_base_dilation(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), @@ -464,8 +464,8 @@ TEST_F(ShapeInferenceTest, 
ConvolveWithWindowDilation) { dim1->set_padding_high(1); dim1->set_window_dilation(2); dim1->set_base_dilation(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), @@ -509,8 +509,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dim1->set_padding_high(1); dim1->set_window_dilation(1); dim1->set_base_dilation(2); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), @@ -547,8 +547,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dim1->set_stride(2); dim1->set_padding_low(1); dim1->set_padding_high(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("each dimension exactly once")); diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc deleted file mode 100644 index dd53c7531bea4273b5f8dc1c993e7720eb1afeb2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/source_map_util.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/source_map_util.h" - -#include "absl/strings/str_format.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { -namespace source_map_util { -namespace { - -Status InvalidParameterArgumentV(const OpMetadata& op_metadata, - const char* format, va_list args) { - string message; - tensorflow::strings::Appendv(&message, format, args); - if (!op_metadata.source_file().empty()) { - absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(), - op_metadata.source_line()); - } - return InvalidArgument("%s", message); -} - -} // namespace - -Status InvalidParameterArgument(const OpMetadata& op_metadata, - const char* format, ...) { - va_list args; - va_start(args, format); - Status result = InvalidParameterArgumentV(op_metadata, format, args); - va_end(args); - return result; -} - -Status InvalidParameterArgument(Executable* executable, int parameter_number, - const char* format, ...) 
{ - va_list args; - va_start(args, format); - if (executable != nullptr && executable->has_module()) { - const HloModule& module = executable->module(); - const HloComputation& computation = *module.entry_computation(); - HloInstruction* param = computation.parameter_instruction(parameter_number); - const OpMetadata& metadata = param->metadata(); - Status result = InvalidParameterArgumentV(metadata, format, args); - va_end(args); - return result; - } - Status result = InvalidArgumentV(format, args); - va_end(args); - return result; -} - -} // namespace source_map_util -} // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index b8d2d546e5d4dc67e3f314dfc6dcd4e8df5451c5..a21e586efadb85d18e88e44999283b28f7f65eac 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -42,9 +42,9 @@ TransferManager::GetPlatformTransferManagers() { return r; } -StatusOr> TransferManager::TransferLiteralFromDevice( +StatusOr TransferManager::TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer) { - StatusOr> ret; + StatusOr ret; se::Stream* substream = stream->GetOrCreateSubStream(); substream->ThenWaitFor(stream); @@ -63,7 +63,7 @@ StatusOr> TransferManager::TransferLiteralFromDevice( if (!s.ok()) { return s; } - return absl::make_unique(std::move(literal)); + return std::move(literal); } Status TransferManager::TransferLiteralFromDevice( @@ -99,10 +99,10 @@ Status TransferManager::TransferLiteralToDevice( return substream->BlockHostUntilDone(); } -StatusOr> TransferManager::TransferArrayFromDevice( +StatusOr TransferManager::TransferArrayFromDevice( se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source) { - StatusOr> ret; + StatusOr ret; // Implement the synchronous version by waiting on the asynchronous version. 
// Use a substream so that if we are called from a HostCallback we don't // deadlock. @@ -122,7 +122,7 @@ StatusOr> TransferManager::TransferArrayFromDevice( if (!s.ok()) { return s; } - return absl::make_unique(std::move(literal)); + return std::move(literal); } Status TransferManager::TransferArrayToDevice( diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 21725946b3629a4495d8ad6cc1529d712d22e0af..f952e64af2b675b9c0f8a30e9a2bc3c855e34efa 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -57,7 +57,7 @@ class TransferManager { // without waiting for any other operation on a stream to complete. // // This function should be avoided in favor of the asynchronous version below. - virtual StatusOr> TransferLiteralFromDevice( + virtual StatusOr TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer); virtual Status TransferLiteralFromDevice( se::Stream* stream, const ShapedBuffer& device_buffer, @@ -113,9 +113,9 @@ class TransferManager { Status TransferArrayToDeviceAsync(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); - StatusOr> TransferArrayFromDevice( - se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& source); + StatusOr TransferArrayFromDevice(se::Stream* stream, + const Shape& shape, + const se::DeviceMemoryBase& source); // Transfers the given literal into the Infeed interface of the device, // using the given executor. 
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 530f40e4b2f9c7c19fa29dad28a077b9d4d68a71..7c1f4b5cc67dd2a84271b4f2b8015fdb2ff6e846 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -108,8 +108,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) { } std::unique_ptr new_dot = HloInstruction::CreateDot( - dot->shape(), new_lhs, new_rhs, new_dim_numbers); - new_dot->set_precision_config(dot->precision_config()); + dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config()); return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } @@ -178,8 +177,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { } auto new_conv = HloInstruction::CreateConvolve( - convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); - new_conv->set_precision_config(convolution.precision_config()); + convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(), + convolution.window(), new_dnums, convolution.precision_config()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 58f767e913fbc0023e0c45a4f0e82ecefeeef2d6..79b5c09abb355cd067a4891af558c8c44d80d88e 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -240,10 +240,12 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), window, dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv 
= builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + conv_shape.ValueOrDie(), x, transpose_y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -293,10 +295,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), window, dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + conv_shape.ValueOrDie(), x, transpose_y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -351,10 +355,12 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), window, dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + conv_shape.ValueOrDie(), transpose_x, y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -415,10 +421,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), window, dnums); + 
transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + conv_shape.ValueOrDie(), transpose_x, y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index a32d1f9026e8beae77b5b40241995707ff62231e..e9a07b14ed685fa4388aca583395370a60176cca 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -555,10 +555,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { // Construct a tuple constant and kCopy it. Verify the points-to set of the // copy correctly correctly points into the nested elements of the constant. 
auto builder = HloComputation::Builder(TestName()); - auto tuple_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()}))); + Literal elements[] = {LiteralUtil::CreateR2({{1.0}, {2.0}}), + LiteralUtil::CreateR1({2.0, 42})}; + auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::MakeTuple({&elements[0], &elements[1]}))); auto copy = builder.AddInstruction(HloInstruction::CreateUnary( tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); @@ -1064,8 +1064,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfig::DEFAULT); auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index 39b693872da6bd985d95c2abc9519662c838a3f5..516754e2110ee50a597818c4a8bcfbfbb76c5cec 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class TupleSimplifierTest : public HloTestBase { +class TupleSimplifierTest : public HloVerifiedTestBase { protected: void Run(HloModule* module, bool change_expected) { TupleSimplifier simplifier; @@ -68,7 +68,7 @@ TEST_F(TupleSimplifierTest, TupleOfParameters) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { @@ -81,7 +81,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); } TEST_F(TupleSimplifierTest, GteOfTuple) { @@ -103,7 +103,7 @@ TEST_F(TupleSimplifierTest, GteOfTuple) { EXPECT_THAT(computation->root_instruction(), gte); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param1); } @@ -131,7 +131,7 @@ TEST_F(TupleSimplifierTest, GteOfTupleChain) { EXPECT_THAT(computation->root_instruction(), op::Negate(op::GetTupleElement(op::Tuple()))); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter())); } @@ -162,7 +162,7 @@ TEST_F(TupleSimplifierTest, NestedGteOfTuples) { EXPECT_THAT(computation->root_instruction(), element); - Run(module.get(), /*change_expected=*/true); + 
Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), param); } @@ -187,7 +187,7 @@ TEST_F(TupleSimplifierTest, TupleOfGteInstructions) { EXPECT_THAT(computation->root_instruction(), tuple); - Run(module.get(), /*change_expected=*/true); + Run(module, /*change_expected=*/true); EXPECT_THAT(computation->root_instruction(), tuple_param); } @@ -212,7 +212,7 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { EXPECT_THAT(computation->root_instruction(), tuple); - Run(module.get(), /*change_expected=*/false); + Run(module, /*change_expected=*/false); EXPECT_THAT(computation->root_instruction(), tuple); } @@ -281,7 +281,7 @@ TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { entry = module->AddEntryComputation(builder.Build()); } - Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true); + Run(module, /*change_expected=*/true, /*exclude_entry=*/true); EXPECT_THAT(c0->root_instruction(), p0); EXPECT_THAT(c1->root_instruction(), p1); diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index c3c2603c7eb58d3e57346d2ea1e0058f8e5d7fe8..541b117e0299c94de330604ec5c16e20f07c425f 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -183,8 +183,7 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, HloEvaluator evaluator(/*max_loop_iterations=*/0); auto* while_init = while_op->mutable_operand(0); auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); - StatusOr> indvar_init_result = - evaluator.Evaluate(indvar_init); + StatusOr indvar_init_result = evaluator.Evaluate(indvar_init); if (!indvar_init_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable init: " << indvar_init_result.status(); @@ -197,31 +196,27 @@ optional ComputeWhileLoopTripCount(HloInstruction* while_op, auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); // The 
initial value of the induction variable. - std::unique_ptr indvar_iter_val = - std::move(indvar_init_result).ValueOrDie(); + Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie(); for (int64 trip_count = 0; trip_count != max_value_returned + 1; ++trip_count) { auto* while_cond = while_op->while_condition(); auto* while_cond_root = while_cond->root_instruction(); auto* while_cond_indvar = NonConstantOperand(while_cond_root); - StatusOr> result = - evaluator.EvaluateWithSubstitutions( - while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}}); + StatusOr result = evaluator.EvaluateWithSubstitutions( + while_cond_root, {{while_cond_indvar, &indvar_iter_val}}); if (!result.ok()) { VLOG(2) << "Couldn't evaluate while cond: " << result.status(); return nullopt; } - if (result.ValueOrDie()->data() == absl::Span{false}) { + if (result.ValueOrDie().data() == absl::Span{false}) { VLOG(2) << "Loop has static trip count of " << trip_count; return trip_count; } // Calculate the value of the induction variable after one iteration of the // loop, and check whether the while condition is true with this new value. - StatusOr> indvar_next_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, - {{while_body_indvar, indvar_iter_val.get()}}); + StatusOr indvar_next_result = evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}}); if (!indvar_next_result.ok()) { VLOG(2) << "Couldn't evaluate induction variable update: " << indvar_next_result.status(); diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index aab11806621746141f4302f39a780fcdbab99fc1..56145822be70f391ac3eaab5fc17db4a80e1b9cc 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -15,10 +15,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 52c895e8d4b2aa55b55df41b7139b00c576d6e99..df610102b4c7fa08c0b7030124939009130f89f4 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -224,14 +224,13 @@ class ShapeTree { // REQUIRES: index must exist in the ShapeTree. iterator find(ShapeIndexView index) { Node* element = Lookup(index); - return iterator(&nodes_, typename std::vector::iterator(element), - /*iterate_leaves_only=*/false); + auto element_iter = nodes_.begin() + (element - &nodes_[0]); + return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); } const_iterator find(ShapeIndexView index) const { Node* element = Lookup(index); - return iterator(&nodes_, - typename std::vector::const_iterator(element), - /*iterate_leaves_only=*/false); + auto element_iter = nodes_.cbegin() + (element - &nodes_[0]); + return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false); } // Returns the number of leaf nodes in the tree. 
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 9772c06bce32cef0d79a036b525c3606ea60e31b..96c80fd577e2601c972e374a153f4f0706902ec2 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -441,6 +441,19 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return count; } +/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape, + PrimitiveType primitive_type) { + if (shape.element_type() == primitive_type) { + return true; + } + for (const Shape& element_shape : shape.tuple_shapes()) { + if (HasPrimitiveType(element_shape, primitive_type)) { + return true; + } + } + return false; +} + /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 8234fcdd3f57978b94630d4e2880826dd678389f..623ae39de819ebecdc8aee27a2b31176421ef020 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -180,6 +180,10 @@ class ShapeUtil { // As ElementsIn(), but recurses through tuples. static int64 ElementsInRecursive(const Shape& shape); + // Returns true if shape has the primitive type, recurses through tuples. + static bool HasPrimitiveType(const Shape& shape, + PrimitiveType primitive_type); + // Returns true if 'shape' is an array with zero elements. 
static bool IsZeroElementArray(const Shape& shape); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 6ca4085aaf3bd1c181da3b94aa6c570e21172d0a..c622ecdca1fd66604d1a6ceaf705f2e70edaee55 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -445,6 +445,22 @@ TEST(ShapeUtilTest, ElementsIn) { EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17}))); } +TEST(ShapeUtilTest, HasPrimitiveType) { + EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32)); + EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32)); + EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}), + S32)); + EXPECT_TRUE(ShapeUtil::HasPrimitiveType( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}), + S16)); +} + TEST(ShapeUtilTest, IsZeroElementArray) { EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {}))); EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0}))); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 36b8fb26440f0f71207cc9b2af4d14f21e618cfe..30e3077edb93e1ac740c1d863aacce975ad4c8a5 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -75,7 +75,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_headers_lib", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -648,6 +647,7 @@ xla_test( ], shard_count = 48, tags = [ + "broken", 
"manual", "notap", ], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 0bf4556b437fb1717a9c9773834fa3031cfbd6ea..c257566fb218d4769aec0c793efb9256b023b7ea 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -41,7 +41,6 @@ limitations under the License. namespace xla { namespace { - class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: ErrorSpec error_spec_{0.0001, 0.0001}; @@ -227,10 +226,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0x8000000000000000LL, 0x8000000000000000LL, 1}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::unique_ptr lhs_data = - client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(lhs_literal).ConsumeValueOrDie(); std::vector rhs{1, 0x7FFFFFFFFFFFFFFLL, @@ -241,10 +240,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { 0, 1, 0x8000000000000000LL}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); std::unique_ptr rhs_data = - client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(rhs_literal).ConsumeValueOrDie(); Add(lhs_param, rhs_param); @@ -267,10 +266,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 1, 0, -1}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = 
Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::unique_ptr lhs_data = - client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(lhs_literal).ConsumeValueOrDie(); std::vector rhs{-1, 0, @@ -280,10 +279,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { 0x7FFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL, 0x7FFFFFFFFFFFFFFFLL}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); std::unique_ptr rhs_data = - client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + client_->TransferToServer(rhs_literal).ConsumeValueOrDie(); Sub(lhs_param, rhs_param); @@ -299,16 +298,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) { XlaBuilder b(TestName()); std::vector lhs{static_cast(0x8000000000000000ULL)}; - std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); - auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + Literal lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param"); std::vector rhs{static_cast(0x7FFFFFFFFFFFFFFFULL)}; - std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); - auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + Literal rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param"); Lt(lhs_param, rhs_param); - ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)}); + ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)}); } TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { @@ -321,16 +320,16 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { b_values.push_back(2 * i / static_cast(count + 2)); } - std::unique_ptr a_literal = LiteralUtil::CreateR1({a_values}); + Literal a_literal = 
LiteralUtil::CreateR1({a_values}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); auto a_constant = ConstantR1(&builder, a_values); - auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param"); + auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param"); - std::unique_ptr b_literal = LiteralUtil::CreateR1({b_values}); + Literal b_literal = LiteralUtil::CreateR1({b_values}); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param"); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); + auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param"); auto b_param = ConstantR1(&builder, b_values); auto sum1 = Add(a_constant, b_constant); @@ -1422,12 +1421,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr param_literal = LiteralUtil::CreateR1(values); + Literal param_literal = LiteralUtil::CreateR1(values); std::unique_ptr param_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto sum = ConstantR0(&b, 0.0f); - auto param = Parameter(&b, 0, param_literal->shape(), "param"); + auto param = Parameter(&b, 0, param_literal.shape(), "param"); for (float exponent : exponents) { sum = Add(sum, Pow(param, ConstantR0(&b, exponent))); } @@ -1450,14 +1449,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - 
std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Pow(Exp(param0), param1); std::vector expected(values0.size()); @@ -1475,14 +1474,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Log(Pow(param0, param1)); std::vector expected(values0.size()); @@ -1500,14 +1499,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - 
client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Mul(Exp(param0), Exp(param1)); std::vector expected(values0.size()); @@ -1525,14 +1524,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Div(param0, Exp(param1)); std::vector expected(values0.size()); @@ -1551,20 +1550,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 
= LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + client_->TransferToServer(literal2).ConsumeValueOrDie(); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(Div(param0, param1), param2); std::vector expected(values0.size()); @@ -1583,21 +1582,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - 
client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(param0, Div(param1, param2)); std::vector expected(values0.size()); @@ -1616,21 +1615,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(param0, Pow(param1, param2)); std::vector expected(values0.size()); @@ -1650,26 
+1649,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; std::vector values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f}; - std::unique_ptr literal0 = LiteralUtil::CreateR1(values0); + Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = LiteralUtil::CreateR1(values1); + Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - std::unique_ptr literal2 = LiteralUtil::CreateR1(values2); + Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = - client_->TransferToServer(*literal2).ConsumeValueOrDie(); + client_->TransferToServer(literal2).ConsumeValueOrDie(); - std::unique_ptr literal3 = LiteralUtil::CreateR1(values3); + Literal literal3 = LiteralUtil::CreateR1(values3); std::unique_ptr data3 = - client_->TransferToServer(*literal3).ConsumeValueOrDie(); + client_->TransferToServer(literal3).ConsumeValueOrDie(); - auto param0 = Parameter(&b, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&b, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&b, 2, literal2->shape(), "param2"); - auto param3 = Parameter(&b, 3, literal3->shape(), "param2"); + auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); + auto param3 = Parameter(&b, 3, literal3.shape(), "param2"); Div(Div(param0, param1), Div(param2, param3)); std::vector expected(values0.size()); @@ -2096,18 +2095,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XlaBuilder builder(TestName()); - std::unique_ptr 
param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, @@ -2118,18 +2117,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p0, p1); Array3D expected(0, 7, 0); @@ -2140,13 +2139,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, 
AddTwoParametersZeroElementF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); auto a = ConstantR1(&builder, {1.1f, 2.2f, 3.3f, 4.4f}); - auto p = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto p = Parameter(&builder, 0, param0_literal.shape(), "param0"); Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, @@ -2206,9 +2205,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { 0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05}); TF_ASSERT_OK_AND_ASSIGN(auto input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Tanh(input); ComputeAndCompareR1( @@ -2239,7 +2238,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. 
- std::unique_ptr input_literal = LiteralUtil::CreateR1( + Literal input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5, -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4, @@ -2252,16 +2251,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { 78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3, 86.4, 86.5, 87.6, 87.7, 87.8, 87.9}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Exp(input); std::vector expected_result; - int64 input_size = input_literal->shape().dimensions(0); + int64 input_size = input_literal.shape().dimensions(0); expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(std::exp(input_literal->Get({i}))); + expected_result.push_back(std::exp(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, @@ -2273,7 +2272,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // implementation on XLA CPU. 
XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR1( + Literal input_literal = LiteralUtil::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, -167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9, 198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04, @@ -2290,16 +2289,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { 1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33, 1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Log(input); std::vector expected_result; - int64 input_size = input_literal->shape().dimensions(0); + int64 input_size = input_literal.shape().dimensions(0); expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(std::log(input_literal->Get({i}))); + expected_result.push_back(std::log(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, @@ -2465,10 +2464,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0}); Tuple(&builder, {cmp_dim_0, cmp_dim_1}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{true, true}, {true, false}}).get(), - LiteralUtil::CreateR2({{true, false}, {false, false}}).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{true, true}, {true, false}}), + LiteralUtil::CreateR2({{true, false}, {false, false}})}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { @@ -2821,10 +2820,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { 
std::iota(r1.begin(), r1.end(), 1.0); XlaBuilder builder(TestName()); - std::unique_ptr a_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); - auto a = ConstantLiteral(&builder, *a_literal); + Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); + auto a = ConstantLiteral(&builder, a_literal); auto b = ConstantR1(&builder, r1); Add(a, b, {1}); @@ -2886,11 +2884,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { XlaBuilder builder(TestName()); auto x_literal = LiteralUtil::CreateR1({1, 2, 3}); auto y_literal = LiteralUtil::CreateR1({4, 5}); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); - auto y = Parameter(&builder, 1, y_literal->shape(), "y"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); + auto y = Parameter(&builder, 1, y_literal.shape(), "y"); auto slice = Slice(x, {1}, {2}, {1}); Sub(slice, y); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index ac90a3adb6dbad30e3ef0b11438fb9a6fd6f8574..bc2ba151a38f1ab000b342dcd4bdd8f53d9ce9a9 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -63,7 +63,7 @@ class BatchNormalizationTest {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_)); + input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -242,14 
+242,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, - {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) - .get(), - LiteralUtil::CreateR1({4, 5}).get(), - LiteralUtil::CreateR1({5, 5}).get()}); + {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}), + LiteralUtil::CreateR1({4, 5}), + LiteralUtil::CreateR1({5, 5})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { @@ -267,14 +266,13 @@ XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, - {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) - .get(), - LiteralUtil::CreateR1({4, 5}).get(), - LiteralUtil::CreateR1({5, 5}).get()}); + {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}), + LiteralUtil::CreateR1({4, 5}), + LiteralUtil::CreateR1({5, 5})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { @@ -298,13 +296,12 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { BatchNormTraining(h0, h1, h2, /*epsilon=*/1, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) - .get(), - LiteralUtil::CreateR1(std::vector(260, 1.0f)).get(), - LiteralUtil::CreateR1(std::vector(260, 0.0f)).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)), + 
LiteralUtil::CreateR1(std::vector(260, 1.0f)), + LiteralUtil::CreateR1(std::vector(260, 0.0f))}); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {operand.get(), scale.get(), offset.get()}, ErrorSpec(0.1)); } @@ -331,14 +328,13 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { BatchNormTraining(h0, h1, h2, /*epsilon=*/-100, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR3FromArray3D( - {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) - .get(), - LiteralUtil::CreateR1(std::vector(1, 15.0f)).get(), - LiteralUtil::CreateR1(std::vector(1, 125.0f)).get()}); + {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}), + LiteralUtil::CreateR1(std::vector(1, 15.0f)), + LiteralUtil::CreateR1(std::vector(1, 125.0f))}); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {operand.get(), scale.get(), offset.get()}, ErrorSpec(0.1)); } @@ -363,14 +359,13 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, - {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) - .get(), - LiteralUtil::CreateR1({0, 0}).get(), - LiteralUtil::CreateR1({16, 20}).get()}); + {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}), + LiteralUtil::CreateR1({0, 0}), + LiteralUtil::CreateR1({16, 20})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); } struct BatchNormTestParam { @@ -522,22 +517,22 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); auto 
scale_activations = - Parameter(&builder, 1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal.shape(), "offset"); auto offset_activations = - Parameter(&builder, 2, offset_literal->shape(), "scale"); + Parameter(&builder, 2, offset_literal.shape(), "scale"); - auto expected = LiteralUtil::MakeTuple( - {expected_normalized.get(), LiteralUtil::CreateR1(mean).get(), - LiteralUtil::CreateR1(var).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {expected_normalized, LiteralUtil::CreateR1(mean), + LiteralUtil::CreateR1(var)}); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr offset_data = - client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + client_->TransferToServer(offset_literal).ConsumeValueOrDie(); BatchNormTraining(input_activations, scale_activations, offset_activations, epsilon, feature_index); @@ -547,7 +542,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { // testcase. 
execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); ComputeAndCompareTuple( - &builder, *expected, + &builder, expected, {input_data.get(), scale_data.get(), offset_data.get()}, ErrorSpec(0.01, 1)); } @@ -622,27 +617,27 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { auto input_literal = LiteralUtil::CreateR4FromArray4D(input_array); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); auto scale_activations = - Parameter(&builder, 1, scale_literal->shape(), "offset"); + Parameter(&builder, 1, scale_literal.shape(), "offset"); auto offset_activations = - Parameter(&builder, 2, offset_literal->shape(), "scale"); - auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean"); + Parameter(&builder, 2, offset_literal.shape(), "scale"); + auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean"); auto variance_activations = - Parameter(&builder, 4, var_literal->shape(), "variance"); + Parameter(&builder, 4, var_literal.shape(), "variance"); Array4D expected = normalized; std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr offset_data = - client_->TransferToServer(*offset_literal).ConsumeValueOrDie(); + client_->TransferToServer(offset_literal).ConsumeValueOrDie(); std::unique_ptr mean_data = - client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); + client_->TransferToServer(mean_literal).ConsumeValueOrDie(); std::unique_ptr variance_data = - client_->TransferToServer(*var_literal).ConsumeValueOrDie(); + client_->TransferToServer(var_literal).ConsumeValueOrDie(); BatchNormInference(input_activations, scale_activations, 
offset_activations, mean_activations, variance_activations, epsilon, @@ -811,40 +806,37 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { auto grad_output_literal = LiteralUtil::CreateR4FromArray4D(grad_output_array); - auto input_parameter = - Parameter(&builder, 0, input_literal->shape(), "input"); - auto scale_parameter = - Parameter(&builder, 1, scale_literal->shape(), "scale"); - auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean"); - auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance"); + auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input"); + auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale"); + auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean"); + auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance"); auto grad_output_parameter = - Parameter(&builder, 4, grad_output_literal->shape(), "grad_output"); + Parameter(&builder, 4, grad_output_literal.shape(), "grad_output"); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::unique_ptr scale_data = - client_->TransferToServer(*scale_literal).ConsumeValueOrDie(); + client_->TransferToServer(scale_literal).ConsumeValueOrDie(); std::unique_ptr mean_data = - client_->TransferToServer(*mean_literal).ConsumeValueOrDie(); + client_->TransferToServer(mean_literal).ConsumeValueOrDie(); std::unique_ptr var_data = - client_->TransferToServer(*var_literal).ConsumeValueOrDie(); + client_->TransferToServer(var_literal).ConsumeValueOrDie(); std::unique_ptr grad_output_data = - client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie(); + client_->TransferToServer(grad_output_literal).ConsumeValueOrDie(); BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter, grad_output_parameter, epsilon, feature_index); - auto expected = - 
LiteralUtil::MakeTuple({expected_grad_activation.get(), - LiteralUtil::CreateR1(grad_scale).get(), - LiteralUtil::CreateR1(grad_offset).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {expected_grad_activation, LiteralUtil::CreateR1(grad_scale), + LiteralUtil::CreateR1(grad_offset)}); // Run all HLO passes during this test. In particular, ClientLibraryTestBase // disables constant folding, but we want it enabled for our zero-sized tensor // testcase. execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); - ComputeAndCompareTuple(&builder, *expected, + ComputeAndCompareTuple(&builder, expected, {input_data.get(), scale_data.get(), mean_data.get(), var_data.get(), grad_output_data.get()}, ErrorSpec(0.01, 1)); diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index 65589b0d6af2ffca26776541eb05a093f43e0a9a..e9728e636f0ee032416b2da17a3ea83c5bb18083 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -95,22 +95,19 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4( {{{{static_cast(-1.6875f)}, {static_cast(-2.04f)}}, {{static_cast(0.105f)}, {static_cast(0.66f)}}}, {{{static_cast(1.89f)}, {static_cast(3.35f)}}, - {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) - .get(), + {{static_cast(3.7f)}, {static_cast(6.04f)}}}}), LiteralUtil::CreateR1( - {static_cast(4), static_cast(5)}) - .get(), + {static_cast(4), static_cast(5)}), LiteralUtil::CreateR1( - {static_cast(5), static_cast(5)}) - .get()}); + {static_cast(5), static_cast(5)})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02)); } XLA_TEST_F(Bfloat16Test, BatchNormGrad) { @@ -139,21 
+136,18 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = LiteralUtil::MakeTuple( + auto expected = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR4( {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, {{static_cast(-1.f)}, {static_cast(-1.f)}}}, {{{static_cast(1.f)}, {static_cast(1.f)}}, - {{static_cast(3.f)}, {static_cast(3.f)}}}}) - .get(), + {{static_cast(3.f)}, {static_cast(3.f)}}}}), LiteralUtil::CreateR1( - {static_cast(0), static_cast(0)}) - .get(), + {static_cast(0), static_cast(0)}), LiteralUtil::CreateR1( - {static_cast(16), static_cast(20)}) - .get()}); + {static_cast(16), static_cast(20)})}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index fe4267c73bd170f22a0456533f45e50be823a80b..dde19fb65d65064c9452a6ac49c70e20cf113336 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -60,10 +60,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); - auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout( + auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r3_global_data = - client_->TransferToServer(*r3_data).ConsumeValueOrDie(); + client_->TransferToServer(r3_data).ConsumeValueOrDie(); return r3_global_data; } @@ -74,10 +74,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); 
- auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout( + auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r2_global_data = - client_->TransferToServer(*r2_data).ConsumeValueOrDie(); + client_->TransferToServer(r2_data).ConsumeValueOrDie(); return r2_global_data; } @@ -293,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}}), - ConstantLiteral(&b, *LiteralUtil::CreateR3( + ConstantLiteral(&b, LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); @@ -301,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { LiteralUtil::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } struct R3ImplicitBroadcastSpec { @@ -370,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); ComputeAndCompareLiteral( - &builder, *expected, - {r3_implicit_global_data.get(), r3_global_data.get()}, + &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()}, ErrorSpec(1e-7, 1e-7)); } @@ -395,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { auto expected = LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, + ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 
6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}, {2}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1}, {2}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); + ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); + ConstantLiteral(&b, LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 4}, {4, 
6}}, {{8, 10}, {10, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XlaBuilder b(TestName()); auto r1 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); + &b, LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3({{{1}}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3({{{1}}})); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } struct R2ImplicitBroadcastSpec { @@ -618,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( - &builder, *expected, + &builder, expected, {r2_implicit_global_data1.get(), r2_global_data.get(), r2_implicit_global_data2.get()}, ErrorSpec(1e-6, 1e-6)); @@ -630,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}})); - auto r2 = - ConstantLiteral(&b, 
*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}})); + auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2({{1}, {2}})); - auto r2 = - ConstantLiteral(&b, *LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1}, {2}})); + auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2({{1, 2}, {3, 4}})); Add(r2, r1); auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1, {0}); auto expected = LiteralUtil::CreateR3( {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {1}); auto expected = LiteralUtil::CreateR3( {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } 
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XlaBuilder b(TestName()); auto r1 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {2}); auto expected = LiteralUtil::CreateR3( {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { @@ -697,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { auto r1_1 = ConstantR1(&b, {100, 200}); auto r1_2 = ConstantR1(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { r3 = Add(r1_0, r3, {0}); r3 = Add(r3, r1_1, {1}); @@ -709,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { @@ -730,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); - ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { @@ -739,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2(&b, {{1.0, 5.0}, {1.0, 5.0}}), - 
ConstantLiteral(&b, *LiteralUtil::CreateR3( + ConstantLiteral(&b, LiteralUtil::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 74d4d2eb10c32b270a83aa04dd2e6025d7a56c26..9966e4606ef7f104487182e0240e64e4c9e4d834 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0(42.0), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0(42.0), result, + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { @@ -63,7 +63,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, + LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), result, error_spec_)); } @@ -86,12 +86,12 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - LiteralSlice(*result, {0}), error_spec_)); + LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), + LiteralSlice(result, {0}), error_spec_)); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - LiteralSlice(*result, {1}), error_spec_)); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), + LiteralSlice(result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { @@ -107,7 +107,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto result = 
ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, + LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), result, error_spec_)); } @@ -126,7 +126,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, + LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), result, error_spec_)); } @@ -143,9 +143,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), - *result, error_spec_)); + LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), + result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { @@ -166,9 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -197,9 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -220,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - 
EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array), + result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { @@ -240,9 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { @@ -263,9 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -295,9 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - EXPECT_TRUE( - LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(expected), - *result, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near( + LiteralUtil::CreateR4FromArray4D(expected), result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index b1d18210eaafdfec0920c0cccaa0dfdbd6de5609..8b31e53707eee456e09adfe9fb76f03a8855056d 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -77,8 +77,7 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XlaBuilder builder(TestName()); XlaComputation callee = 
CreateR0F32IdentityComputation(); - auto constant = - ConstantLiteral(&builder, *LiteralUtil::CreateR0(42.0)); + auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0(42.0)); Call(&builder, callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); @@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S0F32AdditionComputation(); - auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); - auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1({})); + auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1({})); + auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1({})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); @@ -98,9 +97,9 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR1S2F32AdditionComputation(); auto x = - ConstantLiteral(&builder, *LiteralUtil::CreateR1({1.0f, 2.0f})); + ConstantLiteral(&builder, LiteralUtil::CreateR1({1.0f, 2.0f})); auto y = - ConstantLiteral(&builder, *LiteralUtil::CreateR1({2.0f, 3.0f})); + ConstantLiteral(&builder, LiteralUtil::CreateR1({2.0f, 3.0f})); Call(&builder, callee, {x, y}); ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); @@ -133,7 +132,7 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr start, - client_->TransferToServer(*LiteralUtil::CreateR0(1.0f))); + client_->TransferToServer(LiteralUtil::CreateR0(1.0f))); ComputeAndCompareR0(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f)); } @@ -141,10 +140,10 @@ XLA_TEST_F(CallOpTest, CallR0F32Tuple) { XlaBuilder builder(TestName()); XlaComputation callee = CreateR0F32TupleComputation(); auto elem = LiteralUtil::CreateR0(42.0); - auto tuple = LiteralUtil::MakeTuple({elem.get()}); - Call(&builder, callee, {ConstantLiteral(&builder, 
*elem)}); + auto tuple = LiteralUtil::MakeTuple({&elem}); + Call(&builder, callee, {ConstantLiteral(&builder, elem)}); - ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); + ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index a4eb57fc7b9abd460a7d158d0dc629eba88018cd..2f1510ff6969757f8091e9c043b61cb2a467ccd5 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -38,14 +38,14 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { XlaBuilder builder("add_two_params"); auto param_literal = LiteralUtil::CreateR1({1.1f, 2.2f}); - auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0"); - auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1"); + auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0"); + auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1"); Add(p0, p1); auto param0_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto param1_data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); @@ -86,12 +86,12 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { auto computation = computation_status.ConsumeValueOrDie(); auto f32_literal = LiteralUtil::CreateR0(1.1f); - auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie(); + auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie(); auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); auto f32_4_data = - client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie(); + 
client_->TransferToServer(f32_4_literal).ConsumeValueOrDie(); auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); - auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie(); + auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie(); // Match auto status = client_->Execute( diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 8a236db0ff2f63332892de822461dd1cc17276ca..fbdf0fcb6543f09dedefef55cfe0f8a5d9067d5a 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -101,7 +101,7 @@ StatusOr> ClientLibraryTestBase::Execute( return client_->Execute(computation, arguments, &execution_options_); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( +StatusOr ClientLibraryTestBase::ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; @@ -113,7 +113,7 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( +StatusOr ClientLibraryTestBase::ExecuteAndTransfer( XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout) { // Build the computation, as a convenience. 
@@ -121,8 +121,7 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); } -StatusOr> -ClientLibraryTestBase::ExecuteAndTransferReference( +StatusOr ClientLibraryTestBase::ExecuteAndTransferReference( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout) { ExecutionOptions execution_options = execution_options_; @@ -148,15 +147,15 @@ string ClientLibraryTestBase::ExecuteToString( if (!result.ok()) { return result.status().ToString(); } else { - return result.ValueOrDie()->ToString(); + return result.ValueOrDie().ToString(); } } void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, absl::Span arguments) { - std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -182,7 +181,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( const string& error_message)>& verify_output) { // Try with no layout requirement. TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments)); - verify_output(*actual, ""); + verify_output(actual, ""); // Try with all output layouts. 
std::vector minor_to_major(ShapeUtil::Rank(expected.shape())); @@ -193,7 +192,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( AsInt64Slice(expected.shape().dimensions()), minor_to_major); TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, &layout)); - verify_output(*actual, + verify_output(actual, absl::StrCat("Test with output layout: ", ShapeUtil::HumanStringWithLayout(layout))); } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); @@ -218,9 +217,9 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_ASSIGN_OR_RETURN(auto literal, client_->Transfer(*arguments[index], nullptr)); // Skip tuples because they don't have a rank. - if (ShapeUtil::IsTuple(literal->shape())) { + if (ShapeUtil::IsTuple(literal.shape())) { layout_strings.push_back( - ShapeUtil::HumanStringWithLayout(literal->shape())); + ShapeUtil::HumanStringWithLayout(literal.shape())); arguments_with_layout.push_back(arguments[index]); TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); @@ -228,15 +227,15 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return Status::OK(); } - std::vector minor_to_major(ShapeUtil::Rank(literal->shape())); + std::vector minor_to_major(ShapeUtil::Rank(literal.shape())); std::iota(minor_to_major.begin(), minor_to_major.end(), 0); do { auto literal_relayout = - literal->Relayout(LayoutUtil::MakeLayout(minor_to_major)); + literal.Relayout(LayoutUtil::MakeLayout(minor_to_major)); layout_strings.push_back( - ShapeUtil::HumanStringWithLayout(literal_relayout->shape())); + ShapeUtil::HumanStringWithLayout(literal_relayout.shape())); TF_ASSIGN_OR_RETURN(auto data, - client_->TransferToServer(*literal_relayout)); + client_->TransferToServer(literal_relayout)); arguments_with_layout.push_back(data.get()); TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); @@ -256,7 +255,7 @@ Status 
ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( for (const auto& str : layout_strings) { absl::StrAppend(&error_message, str, " "); } - verify_output(*actual, error_message); + verify_output(actual, error_message); return Status::OK(); }; @@ -290,11 +289,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. const Literal* expected_ptr = &expected; - std::unique_ptr converted_expected; + Literal converted_expected; Shape layout_shape; if (use_bfloat16_) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); - expected_ptr = converted_expected.get(); + expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( @@ -319,7 +318,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)); return Status::OK(); } @@ -346,11 +345,11 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( // We allow using a float expected literal for a bfloat16 output. In this // case, we need to convert the expected literal to bfloat16. 
const Literal* expected_ptr = &expected; - std::unique_ptr converted_expected; + Literal converted_expected; Shape layout_shape; if (use_bfloat16_) { converted_expected = LiteralUtil::ConvertF32ToBF16(expected); - expected_ptr = converted_expected.get(); + expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( @@ -376,7 +375,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)); return Status::OK(); } @@ -391,12 +390,12 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( auto actual = actual_status.ConsumeValueOrDie(); // Turn the expected value into a literal. - std::unique_ptr expected_literal = LiteralUtil::CreateR1U8(expected); + Literal expected_literal = LiteralUtil::CreateR1U8(expected); - VLOG(1) << "expected: " << expected_literal->ToString(); - VLOG(1) << "actual: " << actual->ToString(); + VLOG(1) << "expected: " << expected_literal.ToString(); + VLOG(1) << "actual: " << actual.ToString(); - EXPECT_EQ(expected, actual->GetR1U8AsString()); + EXPECT_EQ(expected, actual.GetR1U8AsString()); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -408,7 +407,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -420,7 +419,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error)); } void 
ClientLibraryTestBase::ComputeAndCompare( @@ -430,9 +429,9 @@ void ClientLibraryTestBase::ComputeAndCompare( if (!status_or_data.ok()) { return; } - std::unique_ptr reference, result; + Literal reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(reference, result)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -442,12 +441,12 @@ void ClientLibraryTestBase::ComputeAndCompare( if (!status_or_data.ok()) { return; } - std::unique_ptr reference, result; + Literal reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error)); + EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error)); } -StatusOr, std::unique_ptr>> +StatusOr> ClientLibraryTestBase::ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments) { // Transfer the arguments to the executor service. We put the unique_ptr's @@ -569,8 +568,8 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { return ConstantLiteral(builder, use_bfloat16_ - ? *LiteralUtil::ConvertF32ToBF16(literal) - : literal); + ? 
LiteralUtil::ConvertF32ToBF16(literal) + : LiteralSlice(literal)); } std::unique_ptr @@ -600,7 +599,7 @@ Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( const Literal& literal) { if (use_bfloat16_) { - return std::move(*LiteralUtil::ConvertF32ToBF16(literal)); + return LiteralUtil::ConvertF32ToBF16(literal); } return literal.Clone(); } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 22dfdfb0e4c67cc06fa748177c75cf35572196c8..9d32f4f5174a57a53a9d3e6477b46fa4de852f7f 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -95,11 +95,11 @@ class ClientLibraryTestBase : public ::testing::Test { StatusOr> Execute( XlaBuilder* builder, absl::Span arguments); - StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( XlaBuilder* builder, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); - StatusOr> ExecuteAndTransfer( + StatusOr ExecuteAndTransfer( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); @@ -107,7 +107,7 @@ class ClientLibraryTestBase : public ::testing::Test { // This executes the computation via the reference client (which connects a // interpreter backend). The result is used as the expected values of the // computation. 
- StatusOr> ExecuteAndTransferReference( + StatusOr ExecuteAndTransferReference( const XlaComputation& computation, absl::Span arguments, const Shape* shape_with_output_layout = nullptr); @@ -282,7 +282,7 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp AddParam(const Array& argument, XlaBuilder* builder) { - return AddParam(*LiteralUtil::CreateFromArray(argument), builder); + return AddParam(LiteralUtil::CreateFromArray(argument), builder); } // Creates a constant instruction with the given literal. When the @@ -297,14 +297,14 @@ class ClientLibraryTestBase : public ::testing::Test { template XlaOp CreateConstantFromArray(const Array& array, XlaBuilder* builder) { - return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array), + return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array), builder); } // Same as CreateConstantFromArray, but for scalars. template XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { - return CreateConstantFromLiteral(*LiteralUtil::CreateR0(value), + return CreateConstantFromLiteral(LiteralUtil::CreateR0(value), builder); } @@ -375,9 +375,8 @@ class ClientLibraryTestBase : public ::testing::Test { // Executes the computation and calculates the expected reference value using // the reference client. Returns two literals in the order of (expected, // actual). - StatusOr, std::unique_ptr>> - ComputeValueAndReference(XlaBuilder* builder, - absl::Span arguments); + StatusOr> ComputeValueAndReference( + XlaBuilder* builder, absl::Span arguments); Client* client_; Client* ref_client_; // To compute reference result. 
@@ -412,9 +411,8 @@ template void ClientLibraryTestBase::ComputeAndCompareR0( XlaBuilder* builder, NativeT expected, absl::Span arguments) { - std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR0(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -428,9 +426,8 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR0(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -438,9 +435,8 @@ template void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, absl::Span expected, absl::Span arguments) { - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -454,9 +450,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + Literal expected_literal = LiteralUtil::CreateR1(expected); + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -464,9 +459,9 @@ template void ClientLibraryTestBase::ComputeAndCompareR2( XlaBuilder* builder, const Array2D& expected, absl::Span 
arguments) { - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -480,9 +475,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR2FromArray2D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -490,9 +485,9 @@ template void ClientLibraryTestBase::ComputeAndCompareR3( XlaBuilder* builder, const Array3D& expected, absl::Span arguments) { - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -506,9 +501,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR3FromArray3D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -516,9 +511,9 @@ template void ClientLibraryTestBase::ComputeAndCompareR4( XlaBuilder* builder, const Array4D& expected, absl::Span arguments) { - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + 
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments); } @@ -532,9 +527,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); - std::unique_ptr expected_literal = + Literal expected_literal = LiteralUtil::CreateR4FromArray4D(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, + ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal, arguments, error); } @@ -542,13 +537,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR0(value); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR0(value); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -556,13 +551,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( absl::Span values, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR1(values); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR1(values); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - 
client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -570,13 +565,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR2FromArray2D(array_2d); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; } @@ -584,13 +579,13 @@ template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, XlaBuilder* builder, XlaOp* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(array_3d); - if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(*literal); + Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d); + if (use_bfloat16_ && literal.shape().element_type() == F32) { + literal = LiteralUtil::ConvertF32ToBF16(literal); } std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - *data_handle = Parameter(builder, parameter_number, literal->shape(), name); + client_->TransferToServer(literal).ConsumeValueOrDie(); + *data_handle = Parameter(builder, 
parameter_number, literal.shape(), name); return data; } diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index c898dacf489db97223e2918414daf5de88bece64..6f2ca84bb646e88af221ab80b727911ff7d990eb 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -55,16 +55,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { std::unique_ptr data, client_->Execute(computation, {}, &execution_options)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR2WithLayout( - {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); + Literal expected_literal = LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); TF_ASSERT_OK_AND_ASSIGN( - auto computed, client_->Transfer(*data, &expected_literal->shape())); + auto computed, client_->Transfer(*data, &expected_literal.shape())); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( - expected_literal->shape(), computed->shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + expected_literal.shape(), computed.shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } } @@ -91,19 +90,19 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - LiteralSlice(*result, {0})); + LiteralSlice(result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - LiteralSlice(*result, {1})); + LiteralSlice(result, {1})); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result.shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape())); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetTupleElementShape(result->shape(), 0), + ShapeUtil::GetTupleElementShape(result.shape(), 0), 
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{0, 1}))); EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::GetTupleElementShape(result->shape(), 1), + ShapeUtil::GetTupleElementShape(result.shape(), 1), ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{1, 0}))); } @@ -114,7 +113,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr const_arg, client_->TransferToServer( - *LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); + LiteralUtil::CreateR2({{5, 6}, {7, 8}}))); XlaBuilder b(TestName() + ".add"); Add(Parameter(&b, 0, shape, "param_0"), @@ -140,9 +139,9 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { TF_ASSERT_OK_AND_ASSIGN( auto result_literal, - client_->Transfer(*results[0], &expected_result->shape())); + client_->Transfer(*results[0], &expected_result.shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 03d56964998f9abea21d6f82dee8faf86f9fe1d4..6ef7ca035f75966bef12c7abcb55cb59e9b73655 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -42,14 +42,14 @@ class CompilationCacheTest : public ClientLibraryTestBase { absl::Span arguments, float expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; - std::unique_ptr result = + Literal result = client_ ->ExecuteAndTransfer(computation, arguments, /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR0(expected_result), *result, error_spec_)); + LiteralUtil::CreateR0(expected_result), result, error_spec_)); EXPECT_EQ(expect_cache_hit, 
execution_profile.compilation_cache_hit()); } @@ -63,10 +63,9 @@ class CompilationCacheTest : public ClientLibraryTestBase { ->Execute(computation, arguments, &execution_options_, &execution_profile) .ConsumeValueOrDie(); - std::unique_ptr result = - client_->Transfer(*data_handle).ConsumeValueOrDie(); + Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie(); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2(expected_result), *result, error_spec_)); + LiteralUtil::CreateR2(expected_result), result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -88,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = - client_->TransferToServer(*LiteralUtil::CreateR0(42.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(42.0f)) .ConsumeValueOrDie(); std::unique_ptr data_123 = - client_->TransferToServer(*LiteralUtil::CreateR0(123.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(123.0f)) .ConsumeValueOrDie(); std::unique_ptr data_456 = - client_->TransferToServer(*LiteralUtil::CreateR0(456.0f)) + client_->TransferToServer(LiteralUtil::CreateR0(456.0f)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); @@ -145,12 +144,12 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { auto rowmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); auto rowmaj_handle = - client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); + client_->TransferToServer(rowmaj_array).ConsumeValueOrDie(); auto colmaj_array = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); auto colmaj_handle = - client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); + client_->TransferToServer(colmaj_array).ConsumeValueOrDie(); XlaBuilder builder(TestName()); 
Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 8226b6de3f780197bc0f1145b617dba99803927f..3b0414a6045a7c5f4f75948d8ccf2775c575626e 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -69,9 +69,9 @@ class ComputeConstantTest : public ::testing::Test { LOG(FATAL) << "invalid client_type value"; } - StatusOr> ComputeConstantLiteral( - Client* client, const XlaOp& operand, XlaBuilder* builder, - Layout* output_layout = nullptr) { + StatusOr ComputeConstantLiteral(Client* client, const XlaOp& operand, + XlaBuilder* builder, + Layout* output_layout = nullptr) { TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand)); TF_ASSIGN_OR_RETURN(auto computed, client->ComputeConstant(subgraph, output_layout)); @@ -83,7 +83,7 @@ class ComputeConstantTest : public ::testing::Test { XlaBuilder* builder) { TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand, builder, nullptr)); - return literal->Get({}); + return literal.Get({}); } bool IsConstant(const XlaOp& operand, XlaBuilder* builder) { @@ -206,9 +206,8 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR1({4, 6}); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + Literal expected_literal = LiteralUtil::CreateR1({4, 6}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } @@ -221,8 +220,8 @@ TEST_F(ComputeConstantTest, IntegerDivide) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + Literal expected_literal = 
LiteralUtil::CreateR0(5); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } @@ -241,12 +240,11 @@ XLA_TEST_F(ComputeConstantTest, Layout) { ConstantR2(&b, {{10, 20}, {30, 40}})), &b, &layout_proto)); - std::unique_ptr expected_literal = - LiteralUtil::CreateR2WithLayout( - {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); + Literal expected_literal = LiteralUtil::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( - expected_literal->shape(), computed->shape())); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); + expected_literal.shape(), computed.shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed)); } } } diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index be017477d84eb9faf5aa79dcdf54d6b6aaf6fd8e..9811a015e91d866d6f4de6ebb6dac536ed6c7e06 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -536,8 +536,8 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); auto x_literal = LiteralUtil::CreateR0(2.f); auto y_literal = LiteralUtil::CreateR0(3.f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); auto x = Parameter(&builder, 0, f32_scalar, "x"); @@ -559,12 +559,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto x_literal = LiteralUtil::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); auto y_literal = LiteralUtil::CreateR0(1.5f); auto z_literal = LiteralUtil::CreateR0(5.5f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = 
client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); + auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); auto y = Parameter(&builder, 1, f32_scalar, "y"); auto z = Parameter(&builder, 2, f32_scalar, "z"); auto bcast = Broadcast(y, {5}); @@ -587,12 +587,12 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto x_literal = LiteralUtil::CreateR3FromArray3D(x3d); auto y_literal = LiteralUtil::CreateR0(1.5f); auto z_literal = LiteralUtil::CreateR0(5.5f); - auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); - auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); - auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); + auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie(); + auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto x = Parameter(&builder, 0, x_literal->shape(), "x"); + auto x = Parameter(&builder, 0, x_literal.shape(), "x"); auto y = Parameter(&builder, 1, f32_scalar, "y"); auto z = Parameter(&builder, 2, f32_scalar, "y"); auto y_bcast = Broadcast(y, {1, 5, 7}); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index 25d10ab00af11b8ebb8147917e7cdbb21f9a42c4..32cac499c7439af80bafb88ac61b0b078f589599 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -359,8 +359,8 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { 
ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12.0f).get(), - LiteralUtil::CreateR0(25.0f).get()}), + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(12.0f), + LiteralUtil::CreateR0(25.0f)}), {pred_arg.get()}, error_spec_); } @@ -375,12 +375,11 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, CreateR1TupleFloorComputation()); - ComputeAndCompareTuple( - &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({13.0f, 16.0f}).get(), - LiteralUtil::CreateR1({26.0f, 30.0f}).get()}), - {pred_arg.get()}, error_spec_); + ComputeAndCompareTuple(&builder, + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({13.0f, 16.0f}), + LiteralUtil::CreateR1({26.0f, 30.0f})}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a tuple of a predicate, a @@ -415,13 +414,12 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands, false_builder_result.ConsumeValueOrDie()); - ComputeAndCompareTuple( - &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(true).get(), - LiteralUtil::CreateR0(12.2f).get(), - LiteralUtil::CreateR1({12.8f, 14.6f}).get()}), - {pred_arg.get()}, error_spec_); + ComputeAndCompareTuple(&builder, + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(true), + LiteralUtil::CreateR0(12.2f), + LiteralUtil::CreateR1({12.8f, 14.6f})}), + {pred_arg.get()}, error_spec_); } // Test true and false computations that return a nested tuple. 
@@ -463,15 +461,13 @@ XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(46.6f).get(), - LiteralUtil::CreateR1({54.4f, 58.4f}).get()}) - .get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({62.1f, 67.4f}).get(), - LiteralUtil::CreateR0(9.3f).get()}) - .get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(46.6f), + LiteralUtil::CreateR1({54.4f, 58.4f})}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({62.1f, 67.4f}), + LiteralUtil::CreateR0(9.3f)})}), {pred_arg.get()}, error_spec_); } @@ -633,8 +629,8 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { ComputeAndCompareTuple( &builder, - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(a).get(), - LiteralUtil::CreateR0(b).get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(a), LiteralUtil::CreateR0(b)}), {x_arg.get(), y_arg.get()}, error_spec_); }; @@ -669,10 +665,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { { // Pred is true case. std::vector args; - args.push_back(std::move( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(123).get(), - LiteralUtil::CreateR0(-42).get()}))); - args.push_back(std::move(*LiteralUtil::CreateR0(true))); + args.push_back( + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(123), + LiteralUtil::CreateR0(-42)})); + args.push_back(LiteralUtil::CreateR0(true)); XlaBuilder builder(TestName() + ".main"); auto p = Parameter(&builder, 0, tuple2, "p0"); auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); @@ -682,10 +678,10 @@ XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { { // Pred is false case. 
std::vector args; - args.push_back(std::move( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(123).get(), - LiteralUtil::CreateR0(-42).get()}))); - args.push_back(std::move(*LiteralUtil::CreateR0(false))); + args.push_back( + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0(123), + LiteralUtil::CreateR0(-42)})); + args.push_back(LiteralUtil::CreateR0(false)); XlaBuilder builder(TestName() + ".main"); auto p = Parameter(&builder, 0, tuple2, "p0"); auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 49375748319ad5fe40db507a034ec4b07adb7e84..72ff1e74a47c8584cb5336c86a1c978c4637a902 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -110,7 +110,7 @@ TEST_F(ConstantsTest, Small_2x2) { TEST_F(ConstantsTest, Empty_3x0x2) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D( + ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D( Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); @@ -126,7 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D(array3d)); + ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -140,12 +140,11 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { {5.0f, 4.4f}, // p2 }); input_array.FillWithPZ(pz); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4D(input_array); + Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array); { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *input_literal); + ConstantLiteral(&builder, input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } @@ -159,23 +158,21 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): 
Support tuple constants. TEST_F(ConstantsTest, DISABLED_TupleConstant) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()})); + ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0}, {2.0}}), + LiteralUtil::CreateR1({2.0, 42})})); - std::unique_ptr result = - ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); + Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); LiteralTestUtil::ExpectR2Near({{1.0}, {2.0}}, - LiteralSlice(*result, {0}), error_spec_); - LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(*result, {1}), + LiteralSlice(result, {0}), error_spec_); + LiteralTestUtil::ExpectR1Near({2.0, 42.0}, LiteralSlice(result, {1}), error_spec_); } TEST_F(ConstantsTest, Token) { XlaBuilder builder(TestName()); - ConstantLiteral(&builder, *LiteralUtil::CreateToken()); + ConstantLiteral(&builder, LiteralUtil::CreateToken()); // TODO(b/80000000): tokens cannot be returned from computations. 
Tuple(&builder, {}); TF_ASSERT_OK(Execute(&builder, {}).status()); diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 7a203d6873dbb5b69f96c50048c2c5ff3150c544..5f063e67847487f1d18bf4ee80b1634ebdf4183a 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -210,10 +210,10 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { static_cast(0x8000008000000000LL), static_cast(0x8000010000000000LL), }; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, F32); @@ -229,10 +229,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, F32); @@ -247,10 +247,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XlaBuilder builder(TestName()); std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = 
LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, U32); @@ -264,10 +264,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -281,10 +281,10 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -318,10 +318,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { 9223370937343148032.f, -9223371487098961920.f, -9223370937343148032.f}; - std::unique_ptr arg_literal = LiteralUtil::CreateR1({arg}); - auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); + Literal arg_literal = LiteralUtil::CreateR1({arg}); + auto arg_param = Parameter(&builder, 0, 
arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = - client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); @@ -456,7 +456,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*LiteralUtil::CreateR1(input))); + client_->TransferToServer(LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( @@ -476,7 +476,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, - client_->TransferToServer(*LiteralUtil::CreateR1(input))); + client_->TransferToServer(LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 38b6da4fa96b0f6b7ed2d56852eb3ab2872f3520..fd98bf29b8a06d7476d51174b61c6268750db2ec 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -93,8 +93,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, auto weight_array = absl::make_unique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = - client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array)) + client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index d2c6478b02423c93860244bc5eb91e652a3eac2e..070b092d18930027e215cb43ff917e36cac99f12 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest 
{ })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { {7.0f, 8.0f}, })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); // clang-format on ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } }; @@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { Array3D expected({{{510, 610, 710, 810}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public 
ConvolutionTest { Array3D expected({{{570.0f, 670.0f, 770.0f}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { Array3D expected({{{190, 320, 230, 380, 270, 440, 310, 500}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { Array3D expected({{{510, 0, 610, 0, 710, 0, 810}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - 
client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -435,23 +435,23 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota(input_elems.begin(), input_elems.end(), 1.0f); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota(filter_elems.begin(), filter_elems.end(), 1.0f); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446, 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470}); - auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); + auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); - auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie(); + auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r5).ConsumeValueOrDie(); + client_->TransferToServer(filter_r5).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r5, + ComputeAndCompareLiteral(&builder, expected_r5, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -498,23 +498,23 @@ class Convolve2D_1x3x3x5_3x3x5x3_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = 
input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(92115), static_cast(93150), static_cast(94185)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -558,12 +558,12 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(16029), static_cast(16218), static_cast(16407), @@ -571,14 +571,14 @@ class Convolve2D_1x3x3x5_3x3x1x15_Depthwise_Valid : public 
ConvolutionTest { static_cast(18369), static_cast(18576), static_cast(18783), static_cast(19620), static_cast(19836), static_cast(20052), static_cast(20925), static_cast(21150), static_cast(21375)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -624,26 +624,26 @@ class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); iota_int_init_value(input_elems, 1); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); iota_int_init_value(filter_elems, 1); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); auto expected_r1 = LiteralUtil::CreateR1( {static_cast(5076), static_cast(5160), static_cast(5244), static_cast(5328), static_cast(6164), static_cast(6264), static_cast(6364), static_cast(6464), static_cast(7380), static_cast(7496), static_cast(7612), static_cast(7728)}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); + auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + 
client_->TransferToServer(input_r4).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + client_->TransferToServer(filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + ComputeAndCompareLiteral(&builder, expected_r4, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -692,8 +692,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, expected_result.Fill(0); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(param0)), - std::move(*LiteralUtil::CreateFromArray(param1))}, + {LiteralUtil::CreateFromArray(param0), + LiteralUtil::CreateFromArray(param1)}, error_spec_); } @@ -749,26 +749,25 @@ class Convolve1D1WindowTestBase std::vector input_elems(ShapeUtil::ElementsIn(input_shape), static_cast(1.0f)); auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), static_cast(1.0f)); auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); std::vector expect_elems(batch * output_feature * num_windows, static_cast(window_size * input_feature)); auto expected_r1 = LiteralUtil::CreateR1(expect_elems); - auto expected_r3 = - expected_r1->Reshape({batch, num_windows, output_feature}) - .ConsumeValueOrDie(); + auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature}) + .ConsumeValueOrDie(); auto input_literal = - client_->TransferToServer(*input_r3).ConsumeValueOrDie(); + client_->TransferToServer(input_r3).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r3).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r3, + client_->TransferToServer(filter_r3).ConsumeValueOrDie(); + 
ComputeAndCompareLiteral(&builder, expected_r3, {input_literal.get(), filter_literal.get()}, error_spec_); } @@ -868,8 +867,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { })); ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}, + {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}, error_spec_); } @@ -891,9 +890,44 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) { Array4D filter_data(1, 1, 1, 2); filter_data.FillIota(10); - ComputeAndCompare(&builder, - {std::move(*LiteralUtil::CreateFromArray(input_data)), - std::move(*LiteralUtil::CreateFromArray(filter_data))}); + ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data), + LiteralUtil::CreateFromArray(filter_data)}); +} + +XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) { + XlaBuilder builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100}); + Array4D input_data(1, 64, 100, 100); + input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45321); + Shape filter_shape = ShapeUtil::MakeShape(F32, {7, 7, 1, 64}); + Array4D filter_data(7, 7, 1, 64); + filter_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45320); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = ConstantR4FromArray4D(&builder, filter_data); + + // Specify bf01_01io->bf01 as dimension numbers. 
+ ConvolutionDimensionNumbers dnums; + // Input + dnums.set_input_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.add_input_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + // Kernel + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + // Output + dnums.set_output_batch_dimension(0); + dnums.set_output_feature_dimension(1); + dnums.add_output_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(3); + ConvGeneral(input, filter, /*window_strides=*/{1, 1}, + /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums, + /*feature_group_count=*/64); + + ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)}, + error_spec_); } class ConvolutionHloTest : public HloTestBase {}; diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 6784c16715da72d337edf70fa51db42c59404136..ba3e9c436e3cfa574a07e881a187ff4c7d6243a1 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -1335,23 +1335,23 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { auto gradients_flat = LiteralUtil::CreateR1({1}); auto gradients_literal = - gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); - auto gradients = ConstantLiteral(&builder, *gradients_literal); + gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); + auto gradients = ConstantLiteral(&builder, gradients_literal); auto weights_flat = LiteralUtil::CreateR1({1, 10, 100}); auto weights_literal = - weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto weights = ConstantLiteral(&builder, *weights_literal); + weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + auto weights = ConstantLiteral(&builder, weights_literal); auto expected_flat = 
LiteralUtil::CreateR1({10}); auto expected_literal = - expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); + expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto mirrored_weights = Rev(weights, {2, 3, 4}); ConvWithGeneralPadding(gradients, mirrored_weights, /*window_strides=*/{1, 1, 1}, /*padding=*/{{0, 0}, {0, 0}, {1, 1}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); + ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { @@ -1359,17 +1359,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); auto activations_literal = - activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); - auto activations = ConstantLiteral(&builder, *activations_literal); + activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); + auto activations = ConstantLiteral(&builder, activations_literal); auto gradients_flat = LiteralUtil::CreateR1({100, 10, 1}); auto gradients_literal = - gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); - auto gradients = ConstantLiteral(&builder, *gradients_literal); + gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + auto gradients = ConstantLiteral(&builder, gradients_literal); auto expected_flat = LiteralUtil::CreateR1({13, 24, 130}); auto expected_literal = - expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); + expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto forward_conv = ConvGeneralDilated(activations, gradients, @@ -1379,7 +1379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { XlaBuilder::CreateDefaultConvDimensionNumbers( /*num_spatial_dims=*/3)); Transpose(forward_conv, {0, 1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); + ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_); } } // namespace diff 
--git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 526626c1ddd902a4ba6c608f2b9355cece9ec833..1407e68d9a336b6bb1c960711015430f872aa912 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -40,16 +40,16 @@ class CopyOpTest : public HloTestBase { protected: void TestCopyOp(const Literal& literal) { auto builder = HloComputation::Builder(TestName()); - auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(literal.CloneToUnique())); + auto constant = + builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone())); builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopy, constant)); auto computation = builder.Build(); auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result)); + Literal result = ExecuteAndTransfer(std::move(module), {}); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); @@ -58,31 +58,30 @@ class CopyOpTest : public HloTestBase { }; XLA_TEST_F(CopyOpTest, CopyR0Bool) { - TestCopyOp(*LiteralUtil::CreateR0(true)); + TestCopyOp(LiteralUtil::CreateR0(true)); } XLA_TEST_F(CopyOpTest, CopyR1S0U32) { - TestCopyOp(*LiteralUtil::CreateR1({})); + TestCopyOp(LiteralUtil::CreateR1({})); } XLA_TEST_F(CopyOpTest, CopyR1S3U32) { - TestCopyOp(*LiteralUtil::CreateR1({1, 2, 3})); + TestCopyOp(LiteralUtil::CreateR1({1, 2, 3})); } XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) { - TestCopyOp( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) { - 
TestCopyOp(*LiteralUtil::CreateR4( + TestCopyOp(LiteralUtil::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); + TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); } XLA_TEST_F(CopyOpTest, CopyParameterScalar) { @@ -90,7 +89,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { // Copy literal to device to use as parameter. auto literal = LiteralUtil::CreateR0(42.0); - Shape shape = literal->shape(); + Shape shape = literal.shape(); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")); @@ -102,9 +101,8 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = - ExecuteAndTransfer(std::move(module), {literal.get()}); - LiteralTestUtil::ExpectR0Near(42.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {&literal}); + LiteralTestUtil::ExpectR0Near(42.0f, result, error_spec_); } XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { @@ -123,19 +121,17 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, *result, + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR2Near({{1.0, 2.0}, {3.0, 4.0}}, result, error_spec_); } XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // Reverse the minor-to-major order of the literal. 
- Layout* literal_layout = - literal->mutable_shape_do_not_use()->mutable_layout(); + Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); literal_layout->mutable_minor_to_major()->SwapElements(0, 1); @@ -149,11 +145,11 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); // The result of the computation has the default layout, which is the inverse // of the layout of the source literal. - LiteralTestUtil::ExpectR2Near({{1.0, 3.0}, {2.0, 4.0}}, *result, + LiteralTestUtil::ExpectR2Near({{1.0, 3.0}, {2.0, 4.0}}, result, error_spec_); } @@ -169,7 +165,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(a); + Literal literal = LiteralUtil::CreateR3FromArray3D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -182,9 +178,9 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0})); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR3EqualArray3D(a, *result); + LiteralTestUtil::ExpectR3EqualArray3D(a, result); } void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, @@ -203,7 +199,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR4FromArray4D(a); + Literal 
literal = LiteralUtil::CreateR4FromArray4D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -216,9 +212,9 @@ void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, auto module = CreateNewModule(); module->AddEntryComputation(std::move(computation)); ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation)); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR4EqualArray4D(a, *result); + LiteralTestUtil::ExpectR4EqualArray4D(a, result); } XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) { @@ -250,11 +246,11 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { XlaBuilder builder(TestName()); Parameter(&builder, 0, in_shape, "input"); - auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); + auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie(); auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) .ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc index d12a4e7fcd7813775a81677bcaa07af60ff9b477..410732c07b7b6d3ece33ab11f4778241dc53ca50 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -46,7 +46,7 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal = LiteralUtil::CreateR1({1, 2, 3}); - EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); + EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal})); } XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { @@ 
-68,9 +68,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); auto literal1 = LiteralUtil::CreateR1({10, 20}); - EXPECT_EQ( - *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), - *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}), + ExecuteAndTransfer(std::move(module), {&literal0, &literal1})); } // On the GPU backend, constants get special handling. Someone might pass a @@ -95,8 +94,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = LiteralUtil::CreateR1({1, 2, 3}); auto literal1 = LiteralUtil::CreateR1({10, 20}); - EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}), - *ExecuteAndTransfer(std::move(module), {literal0.get()})); + EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}), + ExecuteAndTransfer(std::move(module), {&literal0})); } } // namespace diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 6f7fc0e6e52a69387a4c491871b6fcd97ac638b6..a693fa35954bcb2d95074c94d0aa3eabc1d5fd62 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -80,8 +80,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { module->AddEntryComputation(builder.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR0Near(44.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR0Near(44.0f, result, error_spec_); } XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { @@ -101,8 +101,8 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { 
module->AddEntryComputation(builder.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectR0Near(10.0f, *result, error_spec_); + Literal result = ExecuteAndTransfer(std::move(module), {}); + LiteralTestUtil::ExpectR0Near(10.0f, result, error_spec_); } XLA_TEST_F(CustomCallTest, @@ -125,9 +125,9 @@ XLA_TEST_F(CustomCallTest, module->AddEntryComputation(b.Build()); - std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); + Literal result = ExecuteAndTransfer(std::move(module), {}); LiteralTestUtil::ExpectR3EqualArray3D( - Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result); + Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result); } class CustomCallClientAPITest : public ClientLibraryTestBase {}; diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index eb15fc0593adf2d1bd84da4d0f708b6244f0fb33..e0f23b0fa807ca27038afa2eec5f739508e3d5bd 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -64,11 +64,11 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { // Try copying the elements back and comparing it auto handles = result_status.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { @@ -86,19 +86,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { auto handles1 = result_status1.ConsumeValueOrDie(); auto handles2 = result_status2.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal 
literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); handles1[0].reset(); handles1[1].reset(); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { @@ -116,15 +116,15 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { // the same as handle[3] and handle[1] should be the same as handle[2]. 
auto handles = result_status.ConsumeValueOrDie(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { @@ -142,19 +142,19 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { // should not have been deallocated because of reference counting. 
global_data.reset(); - std::unique_ptr literal; + Literal literal; TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1])); - LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, *literal); + LiteralTestUtil::ExpectR1Equal({2.0, 4.0, 6.0, 8.0}, literal); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); /// Try deallocating one of the repeated elements, then copy handles[0].reset(); TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2])); - LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, *literal); + LiteralTestUtil::ExpectR1Equal({1.0, 2.0, 3.0, 4.0}, literal); } TEST_F(DeconstructTupleTest, DeconstructNonTuple) { @@ -170,10 +170,9 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal param0_literal = LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); Tuple(&builder, {p}); auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()}); diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 5873516442fa63de47360acaa353abb3a97fe881..0171f515839d556827f0723772214d175939d386 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -68,16 +68,16 @@ 
XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { XlaOp param; auto param_data = CreateParameterAndTransferLiteral( 0, - *LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1, 2}, {3, 4}}).get(), - LiteralUtil::CreateR2({{5, 6}, {7, 8}}).get()}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1, 2}, {3, 4}}), + LiteralUtil::CreateR2({{5, 6}, {7, 8}})}), "arg0", &builder, &param); auto lhs = GetTupleElement(param, 0); auto rhs = GetTupleElement(param, 1); Dot(lhs, rhs); ComputeAndCompareLiteral(&builder, - *LiteralUtil::CreateR2({{19, 22}, {43, 50}}), + LiteralUtil::CreateR2({{19, 22}, {43, 50}}), {param_data.get()}); } @@ -196,11 +196,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, FusedDot) { auto lhs_handle = this->client_ - ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}})) .ConsumeValueOrDie(); auto rhs_handle = this->client_ - ->TransferToServer(*LiteralUtil::CreateR2FromArray2D( + ->TransferToServer(LiteralUtil::CreateR2FromArray2D( {{1.0f}, {2.0f}, {3.0f}, {4.0f}})) .ConsumeValueOrDie(); @@ -219,14 +219,14 @@ class SquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f}, {3.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -286,24 +286,23 @@ void ParametricDotTest::TestImpl() { std::unique_ptr> dot_lhs_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); - std::unique_ptr dot_lhs_lit = -
LiteralUtil::CreateR2FromArray2DWithLayout( - *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor( - param.dot_lhs_row_major))); + Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout( + *dot_lhs_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); std::unique_ptr dot_lhs_handle = - client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); + client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie(); std::unique_ptr> dot_rhs_data = MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); Layout rhs_layout = LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); - std::unique_ptr dot_rhs_lit = + Literal dot_rhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); std::unique_ptr dot_rhs_handle = - client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); + client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie(); std::unique_ptr> addend_data; - std::unique_ptr addend_lit; + Literal addend_lit; std::unique_ptr addend_handle; if (param.has_addend) { @@ -311,7 +310,7 @@ void ParametricDotTest::TestImpl() { addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout( *addend_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.addend_row_major))); - addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); + addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie(); } XlaBuilder builder(TestName()); @@ -477,14 +476,14 @@ class NonsquareMatrixDot : public DotOperationTest { void TestImpl(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout( + 
->TransferToServer(LiteralUtil::CreateFromArrayWithLayout( {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}}, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) @@ -511,12 +510,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); } XLA_TEST_F(DotOperationTest, MatrixVectorC64) { auto lhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateR2WithLayout( + ->TransferToServer(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*LiteralUtil::CreateR2WithLayout( + ->TransferToServer(LiteralUtil::CreateR2WithLayout( {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); @@ -584,7 +583,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1000.0f, 100.0f}, {10.0f, 1.0f}}, {{2000.0f, 200.0f}, {20.0f, 2.0f}}}, {{{3000.0f, 300.0f}, {30.0f, 3.0f}}, @@ -592,7 +591,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{11.0f, 22.0f}, {33.0f, 44.0f}}, {{55.0f, 66.0f}, {77.0f, 88.0f}}}})) @@ -630,13 +629,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) { auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}})) .ConsumeValueOrDie(); auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR3FromArray3D( + ->TransferToServer(LiteralUtil::CreateR3FromArray3D( {{{1.0f, 0.0f}, {0.0f, 1.0f}}, 
{{1.0f, 0.0f}, {0.0f, 1.0f}}})) .ConsumeValueOrDie(); @@ -668,7 +667,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { auto x_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}, {{{9.0f, 10.0f}, {11.0f, 12.0f}}, {{13.0f, 14.0f}, {15.0f, 16.0f}}}})) @@ -676,7 +675,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) { auto y_data = this->client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D( + ->TransferToServer(LiteralUtil::CreateR4FromArray4D( {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}, {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}})) .ConsumeValueOrDie(); @@ -708,14 +707,14 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TransposeFolding) { auto lhs_handle = this->client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( *lhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = this->client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + LiteralUtil::CreateR2FromArray2DWithLayout( *rhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); @@ -778,15 +777,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); + 
LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{53.0f, 74.0f}, {45.0f, 66.0f}}); this->template ComputeAndCompareR2( @@ -827,15 +826,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, TF_ASSERT_OK_AND_ASSIGN( auto arg_0_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_0_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_1_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_1_value_array))); TF_ASSERT_OK_AND_ASSIGN( auto arg_2_value, this->client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); + LiteralUtil::CreateR2FromArray2D(*arg_2_value_array))); Array2D expected({{38.0f, 36.0f}, {93.0f, 91.0f}}); this->template ComputeAndCompareR2( diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 9bf3767ca3e229cd3eb37c1f51c526c7dd2bf0f8..7501c6d957e7afe99b8c530e5f0d575f818367da 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -124,13 +124,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { // vector is special so that it cannot be a Span, which // is what the code below wants. So instead we do this. 
Literal input_values = - std::move(*LiteralUtil::CreateR1(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + LiteralUtil::CreateR1(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie(); Literal expected_values = - std::move(*LiteralUtil::CreateR1(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -150,13 +150,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array2D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -176,13 +176,13 @@ class DynamicSliceTest : public ClientLibraryTestBase { const std::vector& slice_sizes, const Array3D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. @@ -359,17 +359,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { void RunR0(int input_value_int, int update_value_int, const std::vector slice_starts, int expected_value_int) { Literal input_value = - std::move(*LiteralUtil::CreateR0(input_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(input_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_value = - std::move(*LiteralUtil::CreateR0(update_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(update_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_value = - std::move(*LiteralUtil::CreateR0(expected_value_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR0(expected_value_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -390,17 +390,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, absl::Span expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR1(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR1(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR1(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR1(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -421,17 +421,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array2D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -452,17 +452,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { const std::vector slice_starts, const Array3D& expected_values_int) { Literal input_values = - std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(input_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal update_values = - std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(update_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); Literal expected_values = - std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int) - ->Convert(primitive_util::NativeToPrimitiveType()) - .ValueOrDie()); + std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int) + .Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. 
@@ -529,9 +529,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void DumpArray(const string& name, const Array3D values) { - std::unique_ptr literal = - LiteralUtil::CreateR3FromArray3D(values); - LOG(INFO) << name << ":" << literal->ToString(); + Literal literal = LiteralUtil::CreateR3FromArray3D(values); + LOG(INFO) << name << ":" << literal.ToString(); } }; @@ -719,7 +718,7 @@ void BM_DynamicSlice(int num_iters) { auto input_literal = LiteralUtil::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - auto input = ConstantLiteral(&builder, *input_literal); + auto input = ConstantLiteral(&builder, input_literal); // Create dynamic slice start indices as a parameter: shape [4] auto start_indices_shape = ShapeUtil::MakeShape(S32, {4}); @@ -740,7 +739,7 @@ void BM_DynamicSlice(int num_iters) { auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - stream.get(), *start_indices_literal, buffer)); + stream.get(), start_indices_literal, buffer)); std::unique_ptr executable = client diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index 5116e60ca63ef5f94b25b15e6616086fb9e44bbb..b08ece0e63e9472f657b49b533511e9b192d3212 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -31,7 +31,7 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr input, client_->TransferToServer( - *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); + LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256))); XlaBuilder b(TestName() + ".add"); Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1")); diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc 
b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index bf1de02ba9dbd97db9ee31484402fe9b92385219..51b50d456e496c9c01c38fb8539bb3737de16937 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -38,29 +38,29 @@ class ExhaustiveF32ElementwiseOpTest XlaBuilder builder(TestName()); - std::unique_ptr input_literal = + Literal input_literal = LiteralUtil::CreateFromDimensions(F32, {input_size}); for (int64 i = begin; i < end; i++) { if (i >= known_incorrect_range.first && i < known_incorrect_range.second) { // If the operation is known to be buggy on a specific input clamp that // input to 0 under the assumption that the op is at least correct on 0. - input_literal->Set({i - begin}, 0.0f); + input_literal.Set({i - begin}, 0.0f); } else { - input_literal->Set({i - begin}, tensorflow::bit_cast(i)); + input_literal.Set({i - begin}, tensorflow::bit_cast(i)); } } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); + client_->TransferToServer(input_literal)); - auto input = Parameter(&builder, 0, input_literal->shape(), "input"); + auto input = Parameter(&builder, 0, input_literal.shape(), "input"); enqueue_op(&builder, input); std::vector expected_result; expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { - expected_result.push_back(evaluate_op(input_literal->Get({i}))); + expected_result.push_back(evaluate_op(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 7cb2f0cedfc2e74386bb3c01ca0b838e7cdcbce9..9c94acb437e9fc948a4255f7112e2e7a40cfa5fb 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -117,9 +117,9 @@ class FusionTest : public HloTestBase { auto expected = 
LiteralUtil::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); + EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4))); } else { - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } } @@ -222,8 +222,8 @@ XLA_TEST_F(FusionTest, Test) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{0.5}, {2.72}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{0.5}, {2.72}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } // Test whether we emit appropriate code for parameters of fusion instructions. @@ -248,8 +248,8 @@ XLA_TEST_F(FusionTest, Parameter) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, RandomizedParallelPartition) { @@ -283,7 +283,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { // Every element of result should be y = x^2 = 4.0. 
for (int i = 0; i < rand_dim0_size; ++i) { for (int j = 0; j < dim1_size; ++j) { - EXPECT_EQ(4.0, result->Get({i, j})); + EXPECT_EQ(4.0, result.Get({i, j})); } } } @@ -308,8 +308,8 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Near( - *LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); + LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), + ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, ReshapeToScalar) { @@ -323,8 +323,8 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(5), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(5), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { @@ -338,8 +338,8 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { @@ -353,8 +353,8 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by1by1_) { @@ -368,8 +368,8 @@ XLA_TEST_F(FusionTest, 
Reshape_1by1by1_) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(7), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__1by1by1) { @@ -383,8 +383,8 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR3({{{7}}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR3({{{7}}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__) { @@ -398,8 +398,8 @@ XLA_TEST_F(FusionTest, Reshape__) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(7), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { @@ -413,8 +413,8 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_2by3) { @@ -428,8 +428,8 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); 
+ LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_3by3) { @@ -443,8 +443,8 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reverse) { @@ -459,8 +459,8 @@ XLA_TEST_F(FusionTest, Reverse) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({3, 2, 1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({3, 2, 1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReverseNegate) { @@ -477,8 +477,8 @@ XLA_TEST_F(FusionTest, ReverseNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-3, -2, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-3, -2, -1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, BroadcastNegate) { @@ -495,8 +495,8 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-1, -1}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { @@ -513,8 +513,8 @@ XLA_TEST_F(FusionTest, SliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-1, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-1, -3}), + 
ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { @@ -535,8 +535,8 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({-2, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({-2, -3}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { @@ -552,9 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, -2}, {-3, -4}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, TransposeNegate) { @@ -570,9 +570,9 @@ XLA_TEST_F(FusionTest, TransposeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + EXPECT_TRUE( + LiteralTestUtil::Equal(LiteralUtil::CreateR2({{-1, -3}, {-2, -4}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } std::unique_ptr MakeReduceTestComputation() { @@ -602,8 +602,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { HloInstruction::FusionKind::kInput); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR0(15), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(15), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { @@ -624,8 +624,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE( - 
LiteralTestUtil::Equal(*LiteralUtil::CreateR0(-15), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR0(-15), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { @@ -674,8 +674,8 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { HloInstruction::FusionKind::kLoop); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } // When a constant (or other op) which has multiple users is imported @@ -710,8 +710,8 @@ XLA_TEST_F(FusionTest, SharedConstant) { EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); EXPECT_TRUE( - LiteralTestUtil::Equal(*LiteralUtil::CreateR1({8}), - *ExecuteAndTransfer(std::move(hlo_module), {}))); + LiteralTestUtil::Equal(LiteralUtil::CreateR1({8}), + ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } @@ -782,19 +782,17 @@ ENTRY main { } )"; - std::unique_ptr operand = - LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); + Literal operand = LiteralUtil::CreateR2({{0., 0.}, {1., 0.}}); HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseHloString(hlo_text, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, - test_runner_.Execute(std::move(module), {operand.get()}, - /*run_hlo_passes=*/false)); + TF_ASSERT_OK_AND_ASSIGN(Literal result, + test_runner_.Execute(std::move(module), {&operand}, + /*run_hlo_passes=*/false)); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), - *result)); + LiteralUtil::CreateR3({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), + result)); } class FusionClientLibraryTest : public ClientLibraryTestBase {}; 
@@ -821,16 +819,16 @@ XLA_TEST_F(FusionClientLibraryTest, ManyLayoutTransformations) { // where overflow is OK. Array2D arr(32, 32); arr.FillUnique(); - std::unique_ptr l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout( LayoutUtil::MakeLayout({0, 1})); - std::unique_ptr l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout( + Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout( LayoutUtil::MakeLayout({1, 0})); - XlaOp p0 = AddParam(*l1, &b); + XlaOp p0 = AddParam(l1, &b); XlaOp sum = p0; for (int i = 1; i < kNumParams; ++i) { - auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b); + auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b); sum = sum + p0 * pN * pN; } @@ -879,19 +877,19 @@ void BM_ParallelFusion(int num_iters) { auto param0_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); ScopedShapedBuffer buffer0 = - client->LiteralToShapedBuffer(*param0_literal, device_ordinal) + client->LiteralToShapedBuffer(param0_literal, device_ordinal) .ConsumeValueOrDie(); auto param1_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); ScopedShapedBuffer buffer1 = - client->LiteralToShapedBuffer(*param1_literal, device_ordinal) + client->LiteralToShapedBuffer(param1_literal, device_ordinal) .ConsumeValueOrDie(); auto param2_literal = LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); ScopedShapedBuffer buffer2 = - client->LiteralToShapedBuffer(*param2_literal, device_ordinal) + client->LiteralToShapedBuffer(param2_literal, device_ordinal) .ConsumeValueOrDie(); // Build executable. 
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 6d634980449268e509d87ee064fbaaaf59abd195..daa89398a697af9149797d621c3bdca80a00aedd 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -58,10 +58,10 @@ ENTRY main { slice_sizes={1, 3} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) { @@ -79,10 +79,10 @@ ENTRY main { slice_sizes={3, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) { @@ -100,11 +100,10 @@ ENTRY main { slice_sizes={3, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) { @@ -122,11 +121,11 @@ ENTRY main { slice_sizes={1, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, 
{2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) { @@ -144,11 +143,11 @@ ENTRY main { slice_sizes={1, 1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) { @@ -166,13 +165,12 @@ ENTRY main { slice_sizes={1,1,2} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) { @@ -190,13 +188,12 @@ ENTRY main { slice_sizes={1,1,2} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, DynamicSlice) { @@ -214,10 +211,10 @@ ENTRY main { slice_sizes={1,1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), 
start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) { @@ -235,11 +232,10 @@ ENTRY main { slice_sizes={1,1} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ZeroDimBounds) { @@ -257,9 +253,9 @@ ENTRY main { slice_sizes={1, 0} } )"; - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) { @@ -281,11 +277,11 @@ ENTRY main { ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) { @@ -307,11 +303,11 @@ ENTRY main { ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - RunTest(hlo_text, 
operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, NegativeIndex) { @@ -333,11 +329,11 @@ ENTRY main { ROOT result = s32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) { @@ -359,11 +355,11 @@ ENTRY main { ROOT result = u32[6]{0} reshape(gather) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR2( + Literal start_indices = LiteralUtil::CreateR2( {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, OneScalarIndex) { @@ -381,10 +377,10 @@ ENTRY main { slice_sizes={1,3,2} } )"; - std::unique_ptr operand = LiteralUtil::CreateR3( + Literal operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr start_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR0(1); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ScalarResult) { @@ -402,9 +398,9 @@ ENTRY main { slice_sizes={1} } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); - std::unique_ptr start_indices = LiteralUtil::CreateR0(1); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal start_indices = 
LiteralUtil::CreateR0(1); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, ZeroSizedResult) { @@ -422,10 +418,10 @@ ENTRY main { slice_sizes={1, 3} } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) { @@ -446,10 +442,10 @@ ENTRY main { ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({0, 2}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({0, 2}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) { @@ -470,11 +466,10 @@ ENTRY main { ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) { @@ -495,11 +490,11 @@ ENTRY main { ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = + Literal start_indices = LiteralUtil::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + 
RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) { @@ -520,13 +515,12 @@ ENTRY main { ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, @@ -548,13 +542,12 @@ ENTRY main { ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) { @@ -575,10 +568,10 @@ ENTRY main { ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = LiteralUtil::CreateR1({1, 1}); - RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR1({1, 1}); + RunTest(hlo_text, &operand, &start_indices); } XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) { @@ -599,11 +592,10 @@ ENTRY main { ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr start_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - 
RunTest(hlo_text, operand.get(), start_indices.get()); + Literal start_indices = LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, &operand, &start_indices); } class GatherClientLibraryTest : public ClientLibraryTestBase {}; @@ -640,10 +632,10 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr operand_arg, client_->TransferToServer( - *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr indices_arg, - client_->TransferToServer(*LiteralUtil::CreateR1({0, 2}))); + client_->TransferToServer(LiteralUtil::CreateR1({0, 2}))); TF_ASSERT_OK_AND_ASSIGN(std::vector devices, client_->GetDeviceHandles(1)); xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions(); @@ -657,10 +649,9 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { TF_ASSERT_OK_AND_ASSIGN( std::vector> result_data, client_->ExecuteParallel(computation_instances)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, client_->Transfer(*(result_data[0]))); - LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, - *result_literal); + LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, result_literal); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index fc4c68246e62a4baa7a506ec37886102c35c4b3b..bdd4fd7e3d0f585d81e94a3326e6d24bb5c42f39 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -120,6 +120,14 @@ StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, return status_or; } +/* static */ +PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfig::DEFAULT); 
+ return precision_config; +} + DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. @@ -128,21 +136,21 @@ DebugOptions HloTestBase::GetDebugOptionsForTest() { return debug_options; } -StatusOr> HloTestBase::Execute( - std::unique_ptr module, absl::Span arguments) { +StatusOr HloTestBase::Execute(std::unique_ptr module, + absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments); } -std::unique_ptr HloTestBase::ExecuteNoHloPasses( - std::unique_ptr module, absl::Span arguments) { +Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr module, + absl::Span arguments) { return test_runner_ .Execute(std::move(module), arguments, /*run_hlo_passes=*/false) .ValueOrDie(); } -std::unique_ptr HloTestBase::ExecuteAndTransfer( - std::unique_ptr module, absl::Span arguments) { +Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, + absl::Span arguments) { return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } @@ -180,7 +188,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( TF_ASSIGN_OR_RETURN(auto reference, reference_runner_.Execute(std::move(reference_module), arguments, run_hlo_passes)); - return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test, + return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test, error); } @@ -215,13 +223,12 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( ::testing::AssertionResult HloTestBase::RunAndCompare( std::unique_ptr module, const optional& error, const std::function& reference_preprocessor) { - const auto& fake_arguments = - MakeFakeArguments(module.get()).ConsumeValueOrDie(); + auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie(); std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, 
std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return RunAndCompare(std::move(module), fake_argument_ptrs, error, reference_preprocessor); @@ -235,7 +242,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, reference_preprocessor); @@ -269,7 +276,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( std::vector fake_argument_ptrs; absl::c_transform( fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const std::unique_ptr& literal) { return literal.get(); }); + [](const Literal& literal) { return const_cast(&literal); }); return test_runner_ .Execute(std::move(module_or_status.ValueOrDie()), fake_argument_ptrs, /*run_hlo_passes=*/true) diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 4c88257bb27f5504588bba3ee0b14ac53c971225..0ae4bdc104d656946d45008adec9ea3960984545 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -80,6 +80,8 @@ class HloTestBase : public ::testing::Test { static StatusOr RunHloPass(HloPassInterface* hlo_pass, HloModule* module); + static PrecisionConfig DefaultPrecisionConfig(int operands); + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the @@ -113,16 +115,16 @@ class HloTestBase : public ::testing::Test { } // Executes the given module and return the result as a Literal. 
- StatusOr> Execute( - std::unique_ptr module, absl::Span arguments); + StatusOr Execute(std::unique_ptr module, + absl::Span arguments); // Same as above, except the module will be executed without running any HLO // passes on it. - std::unique_ptr ExecuteNoHloPasses( - std::unique_ptr module, absl::Span arguments); + Literal ExecuteNoHloPasses(std::unique_ptr module, + absl::Span arguments); - std::unique_ptr ExecuteAndTransfer( - std::unique_ptr module, absl::Span arguments); + Literal ExecuteAndTransfer(std::unique_ptr module, + absl::Span arguments); // Executes the given hlo module on two backends and compares results. // diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 96f72212f35f5e6e98e2dc24fd9a87891a326e8f..43cca91f64b2c0fbfde5054a361cf0f95302c23d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -155,20 +155,20 @@ class LiteralTestUtil { template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR0(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( absl::Span expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR1(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR2(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR2(expected), actual)); } template @@ -176,46 +176,46 @@ template std::initializer_list>> expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR3(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR3(expected), actual)); } template /* static */ void 
LiteralTestUtil::ExpectR2EqualArray2D( const Array2D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( const Array3D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( const Array4D& expected, const LiteralSlice& actual) { - EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual)); + EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR0(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( absl::Span expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR1(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR2(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR2(expected), actual, error)); } template @@ -223,7 +223,7 @@ template std::initializer_list>> expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR3(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR3(expected), actual, error)); } template @@ -232,28 +232,28 @@ template std::initializer_list>>> expected, const LiteralSlice& actual, const 
ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR4(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error)); + EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index 4151bfae0332ffc706ba730d181c487eabab856f..b6f9b8156b51144e4f74d285b1e4111d098f13c2 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -31,11 +31,11 @@ namespace xla { namespace { TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { - std::unique_ptr literal = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + Literal literal = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR0(64), }); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal)); } TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { @@ -43,15 +43,15 @@ 
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // un-fail an assertion failure. The CHECK-failure is death, so we can make a // death assertion. auto unequal_things_are_equal = [] { - std::unique_ptr lhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + Literal lhs = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(42), + LiteralUtil::CreateR0(64), }); - std::unique_ptr rhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(64).get(), - LiteralUtil::CreateR0(42).get(), + Literal rhs = LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR0(64), + LiteralUtil::CreateR0(42), }); - CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal"; + CHECK(LiteralTestUtil::Equal(lhs, rhs)) << "LHS and RHS are unequal"; }; ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal"); } @@ -61,7 +61,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { auto two = LiteralUtil::CreateR0(2); auto four = LiteralUtil::CreateR0(4); ErrorSpec error(0.001); - CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four"; + CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four"; }; tensorflow::Env* env = tensorflow::Env::Default(); @@ -86,14 +86,14 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result, &literal_proto)); - std::unique_ptr literal = + Literal literal = Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", literal->ToString()); + EXPECT_EQ("2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", literal->ToString()); + EXPECT_EQ("4", literal.ToString()); } else if (result.find("mismatches") != string::npos) { - EXPECT_EQ("true", literal->ToString()); + EXPECT_EQ("true", 
literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } @@ -103,8 +103,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto expected = LiteralUtil::CreateR1({1, 2, 3}); auto actual = LiteralUtil::CreateR1({4, 5, 6}); - ::testing::AssertionResult result = - LiteralTestUtil::Equal(*expected, *actual); + ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual); EXPECT_THAT(result.message(), ::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); EXPECT_THAT(result.message(), @@ -116,7 +115,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1) { {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1( {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); - EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); } TEST(LiteralTestUtilTest, NearComparatorR1Nan) { @@ -124,7 +123,7 @@ TEST(LiteralTestUtilTest, NearComparatorR1Nan) { {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1( {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); - EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); } TEST(LiteralTestUtil, NearComparatorDifferentLengths) { @@ -132,8 +131,8 @@ TEST(LiteralTestUtil, NearComparatorDifferentLengths) { {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); auto b = LiteralUtil::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); - EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); - EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001})); } } // namespace diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc 
index 237a4a361e386e24c2897c42602eb60ca7234731..dbdd20daf0c3a54ed7b6e2a9d3fb73274d77474a 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -45,7 +45,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); int64 allocation_count_before = allocator_->allocation_count(); @@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { DefaultExecutableBuildOptions(), options); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(*result), error_spec_); // At least one allocation should have been performed when executing the // computation. @@ -92,7 +92,7 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { computation, {}, ExecutableBuildOptions().set_device_ordinal(d), ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); // At least one allocation should have been performed when executing the // computation. 
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 1a823cf189b310c62c735419936544ea99fcfbaf..a99b43f4690b3063f76e2cda1e58c9b4ba9a1df4 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -58,7 +58,7 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) { ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - LiteralTestUtil::ExpectR0Near(123.f, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR0Near(123.f, ShapedBufferToLiteral(result), error_spec_); } @@ -68,10 +68,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) { auto y = ConstantR0(&builder, 123.0f); Add(x, y); - auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0(42.0f)); + auto x_value = LiteralToShapedBuffer(LiteralUtil::CreateR0(42.0f)); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value}); - LiteralTestUtil::ExpectR0Near(165.f, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR0Near(165.f, ShapedBufferToLiteral(result), error_spec_); } @@ -81,10 +81,10 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { auto y = ConstantR1(&builder, {}); Add(x, y); - auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1({})); + auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1({})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); - LiteralTestUtil::ExpectR1Near({}, *ShapedBufferToLiteral(result), + LiteralTestUtil::ExpectR1Near({}, ShapedBufferToLiteral(result), error_spec_); } @@ -95,11 +95,11 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); 
LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { @@ -109,14 +109,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ExecutionProfile profile; ScopedShapedBuffer result = ExecuteLocallyOrDie( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions().set_execution_profile(&profile)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); EXPECT_GT(profile.compute_and_transfer_time_ns(), 0); } @@ -128,13 +128,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. - auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( + auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. 
- auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout( + auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); @@ -142,15 +142,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_colmaj), + ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with the parameter values in a different order. ScopedShapedBuffer result_param_swap = ExecuteLocallyOrDie(computation, {&y_array, &x_array}); - LiteralTestUtil::ExpectR2Near( - {{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_param_swap), error_spec_); + LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, + ShapedBufferToLiteral(result_param_swap), + error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { @@ -161,9 +161,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); // Run with col-major result layout. 
ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie( @@ -174,7 +174,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_colmaj), + ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with row-major result layout. @@ -186,7 +186,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(result_rowmaj), + ShapedBufferToLiteral(result_rowmaj), error_spec_); } @@ -198,9 +198,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -208,13 +208,13 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - 
LiteralSlice(*result_literal, {2})); + LiteralSlice(result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { @@ -226,9 +226,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { auto computation = builder.Build().ConsumeValueOrDie(); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); auto y_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); + LiteralUtil::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_array, &y_array}); @@ -236,15 +236,15 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 0})); + LiteralSlice(result_literal, {0, 0})); LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {0, 1})); + LiteralSlice(result_literal, {0, 1})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 2})); + LiteralSlice(result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { @@ -255,7 +255,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { Tuple(&builder, {x, y}); auto array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); + LiteralUtil::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); ExecutableBuildOptions options = DefaultExecutableBuildOptions(); Shape shape_with_layout = ShapeUtil::MakeTupleShape( @@ 
-268,11 +268,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array}, options, DefaultExecutableRunOptions()); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -298,15 +298,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { Tuple(&builder, {array_sum, vector_diff}); auto computation = builder.Build().ConsumeValueOrDie(); - auto x_literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}); - auto y_literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({2.0, 4.0, 6.0}).get(), - LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}}).get()}); + auto x_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR1({42.0, 75.0, 123.0})}); + auto y_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({2.0, 4.0, 6.0}), + LiteralUtil::CreateR2({{55.0, 44.0}, {33.0, 22.0}})}); - auto x_buffer = LiteralToShapedBuffer(*x_literal); - auto y_buffer = LiteralToShapedBuffer(*y_literal); + auto x_buffer = LiteralToShapedBuffer(x_literal); + auto y_buffer = LiteralToShapedBuffer(y_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}); @@ -314,11 +314,11 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + 
Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR1Equal({40.0f, 71.0f, 117.0f}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -344,21 +344,20 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { Tuple(&builder, {negate_array, vector_sum}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR1({42.0, 75.0, 123.0}).get()}) - .get(), - LiteralUtil::CreateR1({222.0, -2.0, 10.0}).get()}); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR1({42.0, 75.0, 123.0})}), + LiteralUtil::CreateR1({222.0, -2.0, 10.0})}); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4}}, - LiteralSlice(*result_literal, {0})); + LiteralSlice(result_literal, {0})); LiteralTestUtil::ExpectR1Equal({264.0, 73.0, 133.0}, - LiteralSlice(*result_literal, {1})); + LiteralSlice(result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -377,24 +376,24 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { Tuple(&builder, {Neg(element_0), Add(element_1, element_1)}); auto computation = builder.Build().ConsumeValueOrDie(); - auto arg_literal = LiteralUtil::MakeTuple( - 
{LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), - LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + LiteralUtil::CreateR2({{11.0, 3.0}, {4.0, 5.0}})}); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); + Literal result_0_literal = ShapedBufferToLiteral(result_0); LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralSlice(*result_0_literal, {0})); + LiteralSlice(result_0_literal, {0})); LiteralTestUtil::ExpectR2Equal({{22.0, 6.0}, {8.0, 10}}, - LiteralSlice(*result_0_literal, {1})); + LiteralSlice(result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); - std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); + Literal result_1_literal = ShapedBufferToLiteral(result_1); LiteralTestUtil::ExpectR2Equal({{1.0, 2.0}, {3.0, 4.0}}, - LiteralSlice(*result_1_literal, {0})); + LiteralSlice(result_1_literal, {0})); LiteralTestUtil::ExpectR2Equal({{44.0, 12.0}, {16.0, 20}}, - LiteralSlice(*result_1_literal, {1})); + LiteralSlice(result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -427,20 +426,19 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { // Feed in a tuple where each two-element vector element is {tuple_index, // -tuple_index}. 
- std::vector> arg_elements; + std::vector arg_elements; for (int i = 0; i < kElementCount; ++i) { arg_elements.push_back(LiteralUtil::CreateR1({1.0f * i, -1.0f * i})); } - std::unique_ptr arg_literal = - LiteralUtil::MakeTupleOwned(std::move(arg_elements)); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements)); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); + {2.0f * i, 0.0f}, LiteralSlice(result_literal, {i}), error_spec_); } } @@ -476,9 +474,9 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. 
- std::vector> outer_tuple_elements; + std::vector outer_tuple_elements; for (int i = 0; i < kFanout; ++i) { - std::vector> inner_tuple_elements; + std::vector inner_tuple_elements; for (int j = 0; j < kFanout; ++j) { inner_tuple_elements.push_back(LiteralUtil::CreateR0(i + j)); } @@ -487,16 +485,16 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { } auto arg_literal = LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements)); - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { - LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}), - error_spec_); + LiteralTestUtil::ExpectR0Near(i + j + i * kFanout + j, + LiteralSlice(result_literal, {i, j}), + error_spec_); } } } @@ -525,23 +523,23 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { auto computation = builder.Build().ConsumeValueOrDie(); // Construct the argument to pass to the computation. 
- std::unique_ptr arg_literal = LiteralUtil::CreateR0(123.0); + Literal arg_literal = LiteralUtil::CreateR0(123.0); for (int i = 0; i < kTupleDepth; ++i) { - std::vector> arg_vector; + std::vector arg_vector; arg_vector.push_back(std::move(arg_literal)); arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector)); } - auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + auto arg_buffer = LiteralToShapedBuffer(arg_literal); ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(result); + Literal result_literal = ShapedBufferToLiteral(result); ShapeIndex index; for (int i = 0; i < kTupleDepth; ++i) { index.push_back(0); } LiteralTestUtil::ExpectR0Equal(165.0, - LiteralSlice(*result_literal, index)); + LiteralSlice(result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { @@ -552,7 +550,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { Add(x, y); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -568,7 +566,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { Neg(x); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); @@ -585,7 +583,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { Neg(x); auto x_array = LiteralToShapedBuffer( - *LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); + LiteralUtil::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions().set_result_layout( @@ -622,7 +620,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { 
DefaultExecutableRunOptions().set_device_ordinal(d)); EXPECT_EQ(d, result.device_ordinal()); LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(result)); + ShapedBufferToLiteral(result)); } } } @@ -666,8 +664,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { // As a check to verify that the computation ran of the device associated // with the stream. This is a weak check, but stronger verification is hard. EXPECT_EQ(d, result.device_ordinal()); - LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(result)); + LiteralTestUtil::ExpectR0Equal(42.0f, ShapedBufferToLiteral(result)); } } @@ -745,11 +742,11 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); + Literal tuple_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR1Equal({2.0f, 4.0f, 6.0f}, - LiteralSlice(*tuple_literal, {0})); + LiteralSlice(tuple_literal, {0})); LiteralTestUtil::ExpectR1Equal({1.0f, 2.0f, 3.0f}, - LiteralSlice(*tuple_literal, {1})); + LiteralSlice(tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { @@ -768,7 +765,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { executable_status.ConsumeValueOrDie(); auto x_array = - LiteralToShapedBuffer(*LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); + LiteralToShapedBuffer(LiteralUtil::CreateR1({0.0f, 1.0f, 2.0f})); ScopedShapedBuffer result = executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); @@ -778,7 +775,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { ->BlockHostUntilDone()); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); + {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { @@ -792,33 +789,33 @@ XLA_TEST_F(LocalClientExecuteTest, 
ShapeBufferToLiteralConversion) { TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, local_client_->ShapedBufferToLiteral(shaped_buffer)); - EXPECT_EQ(literal, *transferred_literal); + EXPECT_EQ(literal, transferred_literal); }; // Array shapes. - test_to_device_and_back(*LiteralUtil::CreateR0(42.0)); - test_to_device_and_back(*LiteralUtil::CreateR0(true)); - test_to_device_and_back(*LiteralUtil::CreateR1({1.0, 42.0, 744.4})); + test_to_device_and_back(LiteralUtil::CreateR0(42.0)); + test_to_device_and_back(LiteralUtil::CreateR0(true)); + test_to_device_and_back(LiteralUtil::CreateR1({1.0, 42.0, 744.4})); test_to_device_and_back( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); // Null shape (empty tuple). - test_to_device_and_back(*LiteralUtil::MakeTuple({})); + test_to_device_and_back(LiteralUtil::MakeTuple({})); // Non-nested tuples. - test_to_device_and_back( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(12223.0).get()})); - test_to_device_and_back( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456.0).get()})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(12223.0)})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456.0)})); // Nested tuple. 
- test_to_device_and_back(*LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456.0).get()}) - .get(), - LiteralUtil::CreateR0(false).get()})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456.0)}), + LiteralUtil::CreateR0(false)})); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { @@ -832,17 +829,17 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, local_client_->ShapedBufferToLiteral(shaped_buffer)); - EXPECT_EQ(literal, *transferred_literal); + EXPECT_EQ(literal, transferred_literal); }; test_to_device_and_back( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); - test_to_device_and_back(*LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); + LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(LiteralUtil::CreateR2({{2, 1}, {4444, 56}})); test_to_device_and_back( - *LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); - test_to_device_and_back(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({1.0, -42.0}).get(), - LiteralUtil::CreateR0(123456789000LL).get()})); + LiteralUtil::CreateR2({{20000000000ULL, 1}, {4444, 56}})); + test_to_device_and_back(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({1.0, -42.0}), + LiteralUtil::CreateR0(123456789000LL)})); } XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { @@ -852,7 +849,7 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { auto constant = ConstantR1(&builder, {1.0f, 2.0f, 3.0f}); Add(in, constant); - std::unique_ptr result; + Literal result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -861,13 +858,13 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedTest) { })); 
ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), + LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); // Join the thread. thread.reset(); - LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { @@ -884,14 +881,14 @@ XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) { [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); ASSERT_IS_OK(local_client_->TransferToInfeedLocal( - *LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), + LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + TF_ASSERT_OK_AND_ASSIGN(Literal result, local_client_->TransferFromOutfeedLocal( shape, local_client_->default_device_ordinal())); - LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } // Benchmark that measures the overhead of the LocalClient API when running a @@ -922,8 +919,8 @@ void BM_LocalClientOverhead(int num_iters) { auto literal = LiteralUtil::CreateR2({{0, 0, 0}, {0, 0, 0}}); auto stream = client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal, - buffer)); + ASSERT_IS_OK( + transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer)); const int kWarmups = 2; diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index a8c68fc7fdbad30068af44606f559ca96603fe66..f90ef22d2d549f451f8af231aea834e9f097b12a 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -136,7 +136,7 @@ ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer( 
.ConsumeValueOrDie(); } -std::unique_ptr LocalClientTestBase::ShapedBufferToLiteral( +Literal LocalClientTestBase::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { return local_client_->ShapedBufferToLiteral(shaped_buffer) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 90095c5d410f1561a1303a0f62f44d22ed5340f9..4027c7b124f8ac6e4b94600871ac32376a3e6467 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -86,8 +86,7 @@ class LocalClientTestBase : public ::testing::Test { // Construct and return a literal containing the array represented by // shaped_buffer. - std::unique_ptr ShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer); + Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer); // Execute the given computation on the local client. With and without // options. diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 0732e195d44d738b264361e43d38259c26a4116e..4d327a6fe9c45174a0666fd573a081e0cfe450d2 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -169,11 +169,11 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. 
XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(42.0); + Literal param0_literal = LiteralUtil::CreateR0(42.0); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {}); ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, @@ -183,11 +183,11 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, @@ -197,12 +197,12 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. 
XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, @@ -211,12 +211,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { TEST_F(MapTest, MapEachF32ElementToS32Constant) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); @@ -224,12 +224,12 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { TEST_F(MapTest, MapEachF32ElementToU32Constant) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); @@ -238,12 +238,12 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) { 
TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1( @@ -255,11 +255,11 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); Map(&builder, {map1}, CreateMulByTwo(), {0}); @@ -271,12 +271,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then // maps (lambda (x) (* x 2)) on the result. 
XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0}); Map(&builder, {map1}, CreateMulByTwo(), {0}); @@ -287,12 +287,12 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + Literal param0_literal = LiteralUtil::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param}, CreateAdderToOne(), {0, 1}); Array2D expected_array( @@ -342,17 +342,17 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. 
XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder), {0}); @@ -365,18 +365,18 @@ TEST_F(MapTest, MapBinaryAdder) { // for Map that used to fail in shape inference (b/28989438). 
XLA_TEST_F(MapTest, AddWithMixedLayouts) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2WithLayout( + Literal param0_literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = LiteralUtil::CreateR2WithLayout( + Literal param1_literal = LiteralUtil::CreateR2WithLayout( {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1})); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), {0, 1}); @@ -391,18 +391,18 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { XLA_TEST_F(MapTest, AddR3_3x0x2) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + Literal param1_literal = LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = 
Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder), {0, 1, 2}); @@ -413,22 +413,22 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - std::unique_ptr param2_literal = + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); + Literal param2_literal = LiteralUtil::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); std::unique_ptr param2_data = - client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); + client_->TransferToServer(param2_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); - auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); + auto param2 = Parameter(&builder, 2, param2_literal.shape(), "param2"); Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0}); ComputeAndCompareR1( @@ -475,17 +475,17 @@ TEST_F(MapTest, MapOperantionWithBuildError) { Add(x, y); auto error_add = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 3.3f, 
4.4f, 5.5f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); + Literal param1_literal = LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, error_add, {0}); StatusOr computation_status = builder.Build(); @@ -513,15 +513,15 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { Pow(x, y); auto power = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + Literal param0_literal = LiteralUtil::CreateR0(2.0f); + Literal param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, power, {}); ComputeAndCompareR0(&builder, 32.0f, @@ -540,15 +540,15 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { Sub(y, x); // 
note that this is y - x, not x - y auto sub_opposite = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + Literal param0_literal = LiteralUtil::CreateR0(2.0f); + Literal param1_literal = LiteralUtil::CreateR0(5.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + client_->TransferToServer(param1_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); - auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); + auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Map(&builder, {param0, param1}, sub_opposite, {}); ComputeAndCompareR0( @@ -565,11 +565,11 @@ TEST_F(MapTestWithFullOpt, MapSquare) { Mul(x, x); auto square = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(10.0f); + Literal param0_literal = LiteralUtil::CreateR0(10.0f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); Map(&builder, {param0}, square, {}); ComputeAndCompareR0(&builder, 100.0f, {param0_data.get()}, diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index edb592f43ec778a3fe6e5ef936827dd612791760..3f278115e078877de1683574370df7790c2801fd 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ 
b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -63,11 +63,11 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { }); Exp(data); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 {0.36788f, 1.64872f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { @@ -92,10 +92,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { }); Map(&builder, {data}, add_half, {0, 1}); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 {-0.5f, 1.0f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { @@ -111,10 +111,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { }); Max(lhs, rhs); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 {3.0f, -4.0f}}); // row 1 - this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); + this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6)); } struct TestLinspaceMaxParam { @@ -200,14 +200,12 @@ class MatOpsDotAddTest TF_ASSERT_OK_AND_ASSIGN( auto lhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); TF_ASSERT_OK_AND_ASSIGN( auto rhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + 
client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); XlaBuilder builder(TestName()); auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs"); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 05f90ba9fb7d781f64bd52008423f603397ce628..56aaeb0e6878737e6c689e8065d8f1e1871b3472 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -47,7 +47,6 @@ limitations under the License. namespace xla { namespace { - class MultiOutputFusionTest : public HloTestBase { protected: MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; } @@ -90,8 +89,8 @@ class MultiOutputFusionTest : public HloTestBase { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2))); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -115,10 +114,10 @@ class MultiOutputFusionTest : public HloTestBase { Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); expect.PopulateWithValue(size * 1.5f * 3.5f); + Literal literal_r0 = LiteralUtil::CreateR0(-9.0f); auto actual = - ExecuteAndTransfer(std::move(hlo_module), - {LiteralUtil::CreateR0(-9.0f).get(), &arg1}); - EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); + ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1}); + EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_)); } void RunTest1D(bool manual_fusion, int size) { @@ -154,7 +153,7 @@ class MultiOutputFusionTest : public HloTestBase { 
dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape, - dot_dnums)); + dot_dnums, DefaultPrecisionConfig(2))); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -179,10 +178,9 @@ class MultiOutputFusionTest : public HloTestBase { Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size})); input1.PopulateWithValue(1.); - Literal expect = - std::move(*LiteralUtil::CreateR1({size * 1.5f * 3.5f})); + Literal expect = LiteralUtil::CreateR1({size * 1.5f * 3.5f}); auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); - EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_)); + EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_)); } }; @@ -219,10 +217,9 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { LiteralUtil::CreateR0(1.0)), LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(3.0), LiteralUtil::CreateR0(4))); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0(42)), result)); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { @@ -248,9 +245,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0, -1.0}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); - LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, *result); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, result); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { 
@@ -281,9 +277,8 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = LiteralUtil::CreateR1({1.0, 2.0, 3.0}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); - LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, *result); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, result); } const char* const kScalarOps = R"( @@ -325,13 +320,12 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -357,13 +351,12 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{6, 8}, {10, 12}}), LiteralUtil::CreateR2({{25, 36}, {49, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -390,13 +383,12 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), - 
LiteralUtil::CreateR1({36, 64}), - LiteralUtil::CreateR1({66, 138})), - *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({14, 22}), + LiteralUtil::CreateR1({36, 64}), + LiteralUtil::CreateR1({66, 138})), + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -423,14 +415,13 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -457,15 +448,14 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{6, 8}, {10, 12}}), LiteralUtil::CreateR3( {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), LiteralUtil::CreateR2({{25, 36}, {49, 64}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -493,16 +483,15 @@ XLA_TEST_F(MultiOutputFusionTest, .ValueOrDie(); auto param = LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR1({14, 22}), LiteralUtil::CreateR3( {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), LiteralUtil::CreateR3( {{{5, 10}, 
{15, 20}}, {{25, 30}, {35, 40}}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -531,13 +520,13 @@ XLA_TEST_F(MultiOutputFusionTest, LiteralUtil::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); auto init1 = LiteralUtil::CreateR0(5); auto init2 = LiteralUtil::CreateR0(6); - std::unique_ptr result = ExecuteNoHloPasses( - std::move(module), {param.get(), init1.get(), init2.get()}); + Literal result = + ExecuteNoHloPasses(std::move(module), {¶m, &init1, &init2}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{167, 172}, {176, 180}}), LiteralUtil::CreateR2({{6, 6}, {6, 8}})), - *result)); + result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -566,10 +555,9 @@ XLA_TEST_F(MultiOutputFusionTest, auto param = LiteralUtil::CreateR3( {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); - std::unique_ptr result = - ExecuteNoHloPasses(std::move(module), {param.get()}); + Literal result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned( + LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR2({{3, 7}, {11, 15}}), LiteralUtil::CreateR2({{5, 16}, {36, 64}}), LiteralUtil::CreateR3( @@ -577,7 +565,7 @@ XLA_TEST_F(MultiOutputFusionTest, {Eigen::half(3), Eigen::half(4)}}, {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}})), - *result)); + result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc index 0a0426adcbc1b5b89be0841fa2c4204e2b65abf4..f2460822a61fef11e5c35c731fa6eca5df72b60b 100644 --- a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc +++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc @@ -70,7 +70,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, 
OutfeedInWhile) { GetTupleElement(result_tuple, 0); TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); - std::unique_ptr comp_result; + Literal comp_result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -81,41 +81,41 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInWhile) { VLOG(1) << "Transferring trip count to computation"; // Transfer number of iterations to Infeed. TF_ASSERT_OK( - local_client_->TransferToInfeed(*LiteralUtil::CreateR0(1))); + local_client_->TransferToInfeed(LiteralUtil::CreateR0(1))); // Pick up value from outfeed { VLOG(1) << "Reading from condition outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&int_shape)); - EXPECT_EQ(r->Get({}), 1); + EXPECT_EQ(r.Get({}), 1); } VLOG(1) << "Writing data to infeed"; // Transfer some stuff to Infeed for use inside of loop. TF_ASSERT_OK(local_client_->TransferToInfeed( - *LiteralUtil::CreateR1({10, 20}))); + LiteralUtil::CreateR1({10, 20}))); // Pick up value from outfeed { VLOG(1) << "Reading from body outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&xfeed_shape)); - EXPECT_EQ(r->Get({0}), 11); - EXPECT_EQ(r->Get({1}), 21); + EXPECT_EQ(r.Get({0}), 11); + EXPECT_EQ(r.Get({1}), 21); } { VLOG(1) << "Reading from condition outfeed"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&int_shape)); - EXPECT_EQ(r->Get({}), 0); + EXPECT_EQ(r.Get({}), 0); } // Joins the thread thread.reset(); - EXPECT_EQ(comp_result->Get({}), 0); + EXPECT_EQ(comp_result.Get({}), 0); } XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { @@ -145,7 +145,7 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build()); - std::unique_ptr 
comp_result; + Literal comp_result; std::unique_ptr thread( tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { @@ -154,12 +154,12 @@ XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) { })); TF_ASSERT_OK( - local_client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); + local_client_->TransferToInfeed(LiteralUtil::CreateR0(true))); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr r, + TF_ASSERT_OK_AND_ASSIGN(Literal r, local_client_->TransferFromOutfeed(&result_shape)); - EXPECT_EQ(r->Get({}), true); + EXPECT_EQ(r.Get({}), true); // Join the thread thread.reset(); diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index cbeddffacfa4a0fc560e8b9f9a8d7bd23ff32e55..6e98167739c234fae335bcc9e024423e7fc87197 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(0); - Pad(AddParam(*LiteralUtil::CreateR1({}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } @@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - Pad(AddParam(*LiteralUtil::CreateR1({}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), padding_config); + Pad(AddParam(LiteralUtil::CreateR1({}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, DefaultErrorSpec()); } @@ -123,8 +123,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - Pad(AddParam(*LiteralUtil::CreateR1({1, 2, 3}), &b), - AddParam(*LiteralUtil::CreateR0(0.1), &b), 
padding_config); + Pad(AddParam(LiteralUtil::CreateR1({1, 2, 3}), &b), + AddParam(LiteralUtil::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } @@ -132,7 +132,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) { XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { XlaBuilder b(TestName()); Pad(AddParam(Array4D(2, 0, 3, 2), &b), - AddParam(*LiteralUtil::CreateR0(1.5), &b), + AddParam(LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, DefaultErrorSpec()); @@ -148,7 +148,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0(1.5), &b), + Pad(AddParam(*input, &b), AddParam(LiteralUtil::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); auto expected = absl::make_unique>(2, 3, 3, 2); @@ -168,7 +168,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); Pad(AddParam(input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), r4_padding_on_dim0_dim1_); auto expected = absl::make_unique>(8, 5, 1, 1); @@ -208,10 +208,10 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { const float pad_value = -5.123f; Array4D input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6}); auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = input->Relayout(layout); + input = input.Relayout(layout); - Pad(AddParam(*input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(input, &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -254,10 +254,10 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { input_array(0, 24, 6, 6) = 2.0f; input_array(0, 17, 
2, 5) = 3.0f; auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = input->Relayout(layout); + input = input.Relayout(layout); - Pad(AddParam(*input, &b), - AddParam(*LiteralUtil::CreateR0(pad_value), &b), padding_config); + Pad(AddParam(input, &b), + AddParam(LiteralUtil::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -331,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - Pad(input, AddParam(*LiteralUtil::CreateR0(0.0f), &b), padding_config); + Pad(input, AddParam(LiteralUtil::CreateR0(0.0f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -353,8 +353,7 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - Pad(input, AddParam(*LiteralUtil::CreateR0(3.14f), &b), - padding_config); + Pad(input, AddParam(LiteralUtil::CreateR0(3.14f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); @@ -379,7 +378,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -407,7 +406,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); 
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -435,7 +434,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - Pad(input, AddParam(*LiteralUtil::CreateR0(2.718f), &b), + Pad(input, AddParam(LiteralUtil::CreateR0(2.718f), &b), padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); @@ -452,13 +451,12 @@ XLA_TEST_P(PadTestFloat, ReducePad) { XlaComputation add = CreateScalarAddComputation(FloatType(), &b); auto reduce = - Reduce(input, AddParam(*LiteralUtil::CreateR0(0.0), &b), add, {0}); + Reduce(input, AddParam(LiteralUtil::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - Pad(reduce, AddParam(*LiteralUtil::CreateR0(0.0f), &b), - padding_config); + Pad(reduce, AddParam(LiteralUtil::CreateR0(0.0f), &b), padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index f6c762e7a4bee91a26c4c2e033c3717fef6d91d0..dcb4c11c3ccab5992e1ea4fadf02fda8ff77e7ea 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -42,10 +42,9 @@ class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR0(3.14159f); + Literal param0_literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0"); @@ -55,9 +54,9 @@ XLA_TEST_F(ParamsTest, 
ConstantR0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + Literal param0_literal = LiteralUtil::CreateR1({}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0"); @@ -67,10 +66,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal param0_literal = LiteralUtil::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0"); @@ -81,9 +79,9 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XlaBuilder builder(TestName()); string str("hello world"); - std::unique_ptr param0_literal = LiteralUtil::CreateR1U8(str); + Literal param0_literal = LiteralUtil::CreateR1U8(str); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(U8, {static_cast(str.size())}), @@ -94,10 +92,10 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0"); @@ -107,10 +105,10 @@ 
XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + Literal param0_literal = LiteralUtil::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0"); @@ -123,15 +121,15 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XLA_TEST_F(ParamsTest, TwoParameters) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + auto param0 = Parameter(&builder, 0, literal0.shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + Literal literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + auto param1 = Parameter(&builder, 1, literal1.shape(), "param1"); // Use both parameters // @@ -154,9 +152,9 @@ XLA_TEST_F(ParamsTest, TwoParameters) { XLA_TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an incomplete set of // parameters (parameter numbers not contiguous from 0) is executed. 
- std::unique_ptr literal = LiteralUtil::CreateR0(3.14159f); + Literal literal = LiteralUtil::CreateR0(3.14159f); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + client_->TransferToServer(literal).ConsumeValueOrDie(); XlaBuilder builder(TestName()); Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2"); @@ -168,15 +166,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) { XLA_TEST_F(ParamsTest, UnusedParameter) { XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); - Parameter(&builder, 0, literal0->shape(), "param0"); + client_->TransferToServer(literal0).ConsumeValueOrDie(); + Parameter(&builder, 0, literal0.shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + Literal literal1 = LiteralUtil::CreateR1({10, 20}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); - Parameter(&builder, 1, literal1->shape(), "param1"); + client_->TransferToServer(literal1).ConsumeValueOrDie(); + Parameter(&builder, 1, literal1.shape(), "param1"); ComputeAndCompareR1(&builder, {10, 20}, {param0_data.get(), param1_data.get()}, @@ -188,18 +186,17 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // unused expression. 
XlaBuilder builder(TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + Literal literal0 = LiteralUtil::CreateR1({1, 2}); std::unique_ptr param0_data = - client_->TransferToServer(*literal0).ConsumeValueOrDie(); + client_->TransferToServer(literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = - LiteralUtil::CreateR1({10, 20, 30}); + Literal literal1 = LiteralUtil::CreateR1({10, 20, 30}); std::unique_ptr param1_data = - client_->TransferToServer(*literal1).ConsumeValueOrDie(); + client_->TransferToServer(literal1).ConsumeValueOrDie(); - auto param0 = Parameter(&builder, 0, literal0->shape(), "param0"); - auto param1 = Parameter(&builder, 1, literal1->shape(), "param1"); - auto param2 = Parameter(&builder, 2, literal1->shape(), "param2"); + auto param0 = Parameter(&builder, 0, literal0.shape(), "param0"); + auto param1 = Parameter(&builder, 1, literal1.shape(), "param1"); + auto param2 = Parameter(&builder, 2, literal1.shape(), "param2"); // This add is unused. 
Add(param1, param2); @@ -233,10 +230,10 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum_value = {{entry0, entry1}}; sum_value.resize(size); - std::unique_ptr literal = LiteralUtil::CreateR1(sum_value); + Literal literal = LiteralUtil::CreateR1(sum_value); param_data_owner.push_back( - client_->TransferToServer(*literal).ConsumeValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + client_->TransferToServer(literal).ConsumeValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); sum_handle = Add(sum_handle, param); } @@ -268,10 +265,10 @@ XLA_TEST_F(ParamsTest, constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = LiteralUtil::CreateR0(i); + Literal literal = LiteralUtil::CreateR0(i); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); sum_handle = Add(sum_handle, param); } @@ -300,10 +297,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( std::vector params; for (int i = 0; i < kParamCount; ++i) { target += i; - std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); + Literal literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); params.push_back(param); sum_handle = Add(sum_handle, param); } @@ -321,13 +318,14 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( param_data.push_back(data.get()); } - std::vector> elements; + std::vector elements; std::vector ptrs; + elements.reserve(kParamCount); for (int i = 0; 
i < kParamCount; ++i) { elements.push_back(LiteralUtil::CreateR1({target + i, target + i})); - ptrs.push_back(elements.back().get()); + ptrs.push_back(&elements.back()); } - ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data); } // Test large number of parameters flowing into a while-loop. @@ -356,23 +354,23 @@ XLA_TEST_F(ParamsTest, std::vector params; std::vector parameter_shapes; for (int i = 0; i < kParamCount; ++i) { - std::unique_ptr literal = LiteralUtil::CreateR1({i, i}); + Literal literal = LiteralUtil::CreateR1({i, i}); param_data_owner.push_back( - std::move(client_->TransferToServer(*literal)).ValueOrDie()); - XlaOp param = Parameter(&builder, i, literal->shape(), "param"); + std::move(client_->TransferToServer(literal)).ValueOrDie()); + XlaOp param = Parameter(&builder, i, literal.shape(), "param"); params.push_back(param); - parameter_shapes.push_back(literal->shape()); + parameter_shapes.push_back(literal.shape()); } // Add bool parameter for the loop condition. Use a parameter HLO instead of a // constant because DCE may eliminate the while-body otherwise. 
- std::unique_ptr bool_literal = LiteralUtil::CreateR0(false); + Literal bool_literal = LiteralUtil::CreateR0(false); param_data_owner.push_back( - std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); + std::move(client_->TransferToServer(bool_literal)).ValueOrDie()); XlaOp bool_param = - Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param"); + Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param"); params.push_back(bool_param); - parameter_shapes.push_back(bool_literal->shape()); + parameter_shapes.push_back(bool_literal.shape()); auto init = Tuple(&builder, params); @@ -420,13 +418,14 @@ XLA_TEST_F(ParamsTest, param_data.push_back(data.get()); } - std::vector> elements; + std::vector elements; std::vector ptrs; + elements.reserve(kParamCount); for (int i = 0; i < kParamCount; ++i) { elements.push_back(LiteralUtil::CreateR1({i, i})); - ptrs.push_back(elements.back().get()); + ptrs.push_back(&elements.back()); } - ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data); + ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data); } #endif @@ -443,9 +442,9 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR1({4, 5, 6}).get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR1({1, 2, 3}), + LiteralUtil::CreateR1({4, 5, 6}), })) .ConsumeValueOrDie(); @@ -457,34 +456,34 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. 
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); XlaBuilder builder(TestName()); - Parameter(&builder, 0, literal->shape(), "input"); + Parameter(&builder, 0, literal.shape(), "input"); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3)); + client_->TransferToServer(literal).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3)); } // As above, but for {1, 0} layout. XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); XlaBuilder builder(TestName()); - Parameter(&builder, 0, literal->shape(), "input"); + Parameter(&builder, 0, literal.shape(), "input"); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3)); + client_->TransferToServer(literal).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3)); } XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + Literal literal = LiteralUtil::CreateR2({ {1, 3}, {2, 4}, }); - const Shape original = literal->shape(); + const Shape original = literal.shape(); { // Reverse the layout present in original, and make that the layout of the // literal. 
@@ -492,9 +491,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { original.layout().minor_to_major().begin(), original.layout().minor_to_major().end()); std::reverse(original_layout.begin(), original_layout.end()); - *literal->mutable_shape_do_not_use()->mutable_layout() = + *literal.mutable_shape_do_not_use()->mutable_layout() = LayoutUtil::MakeLayout(original_layout); - ASSERT_EQ(2, literal->Get({0, 1})); + ASSERT_EQ(2, literal.Get({0, 1})); } // Use the original shape in building the computation. XlaBuilder builder(TestName()); @@ -503,7 +502,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { Slice(input, {0, 1}, {1, 2}, {1, 1}); std::unique_ptr data = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + client_->TransferToServer(literal).ConsumeValueOrDie(); // Check that we got the off-diagonal value that we expected. Array2D expected(1, 1); expected(0, 0) = 2; diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 5f322b768d8620cb64a79bb8fca5fecf282f28f5..8f2c26f0eea9c7a3b33cd77e5977924c1659535a 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -37,8 +37,7 @@ namespace { class PrngTest : public ClientLibraryTestBase { protected: template - std::unique_ptr UniformTest(T a, T b, absl::Span dims, - int64 seed = 42); + Literal UniformTest(T a, T b, absl::Span dims, int64 seed = 42); // Computes the χ² statistic of a sample of the discrete uniform distribution // of the given range size. 
`expected_count` is the number of times each @@ -49,9 +48,8 @@ class PrngTest : public ClientLibraryTestBase { }; template -std::unique_ptr PrngTest::UniformTest(T a, T b, - absl::Span dims, - int64 seed) { +Literal PrngTest::UniformTest(T a, T b, absl::Span dims, + int64 seed) { XlaBuilder builder(TestName()); RngUniform( ConstantR0(&builder, a), ConstantR0(&builder, b), @@ -60,8 +58,8 @@ std::unique_ptr PrngTest::UniformTest(T a, T b, SetSeed(seed); auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); - EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - actual->EachCell([=](absl::Span, T value) { + EXPECT_THAT(dims, ::testing::ElementsAreArray(actual.shape().dimensions())); + actual.EachCell([=](absl::Span, T value) { EXPECT_LE(a, value); EXPECT_LT(value, b); }); @@ -116,11 +114,10 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { constexpr int64 count = 100; for (int64 seed = 0; seed < count; ++seed) { auto result = UniformTest(low, high, {}, /*seed=*/seed); - result->Literal::EachCell( - [&](absl::Span, bfloat16 value) { - int64 index = static_cast((value - low) / interval); - counts[index]++; - }); + result.EachCell([&](absl::Span, bfloat16 value) { + int64 index = static_cast((value - low) / interval); + counts[index]++; + }); } // Each bucket should have similar amount of counts. That is, not more than // 10% of total counts. 
This mostly tests that we don't fall into a 1:2:2 @@ -149,7 +146,7 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, auto actual = ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); std::vector counts(range_size, 0); - actual->EachCell( + actual.EachCell( [&counts](absl::Span, int32 value) { ++counts[value]; }); int64 sum = 0; for (int32 i = 0; i < range_size; ++i) { @@ -192,12 +189,12 @@ XLA_TEST_F(PrngTest, MapUsingRng) { }; XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, - client_->TransferToServer(*param0_literal)); + client_->TransferToServer(param0_literal)); - auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0"); + auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto fn = build_sum_rng(builder); Map(&builder, {param0}, fn, {0}); @@ -210,12 +207,11 @@ XLA_TEST_F(PrngTest, MapUsingRng) { computation, /*arguments=*/{param0_data.get()}, &execution_options)); - EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()), - ShapeUtil::ElementsIn(param0_literal->shape())); - for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) { - EXPECT_GE(actual->data()[i], param0_literal->data()[i]); - EXPECT_LT(actual->data()[i], - param0_literal->data()[i] + 1.0f); + EXPECT_EQ(ShapeUtil::ElementsIn(actual.shape()), + ShapeUtil::ElementsIn(param0_literal.shape())); + for (int i = 0; i < ShapeUtil::ElementsIn(actual.shape()); ++i) { + EXPECT_GE(actual.data()[i], param0_literal.data()[i]); + EXPECT_LT(actual.data()[i], param0_literal.data()[i] + 1.0f); } } @@ -238,15 +234,15 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { ExecutionOptions execution_options2 = execution_options_; execution_options2.set_seed(65); - std::unique_ptr result1; + Literal result1; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( 
result1, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options1)); } - std::unique_ptr result2; - std::unique_ptr result3; + Literal result2; + Literal result3; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( @@ -257,9 +253,9 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options1)); } - std::unique_ptr result4; - std::unique_ptr result5; - std::unique_ptr result6; + Literal result4; + Literal result5; + Literal result6; { TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation()); TF_ASSERT_OK_AND_ASSIGN( @@ -273,11 +269,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options_)); } - EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2)); - EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5)); - EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6)); + EXPECT_TRUE(LiteralTestUtil::Equal(result1, result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(result1, result3)); + EXPECT_FALSE(LiteralTestUtil::Equal(result1, result4)); + EXPECT_FALSE(LiteralTestUtil::Equal(result4, result5)); + EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6)); } XLA_TEST_F(PrngTest, TenValuesN01) { diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index 9af9ea4a2229bb6ca7c3561350f11837f5072a2c..c9096fb29b2019796c42b69de80c63b5fc7c5c3a 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -92,7 +92,7 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) { *reduce_input_shape->mutable_layout() = LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major); - std::unique_ptr reduce_input = LiteralUtil::CreateR4( + Literal reduce_input = LiteralUtil::CreateR4( {{ /*i0=0*/ {/*i1=0*/ {-0.246092796, -0.179497838, -0.161181688}, diff --git 
a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 0916a07f4fa99af6cf25441fa8558a558bfa032f..26e2bfde5cdc19657640f24f31bc008d09ad7106 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -231,11 +231,10 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = - LiteralUtil::CreateR1({input_values}); + Literal a_literal = LiteralUtil::CreateR1({input_values}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); ReducePrecision(a, exponent_bits, mantissa_bits); @@ -255,10 +254,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // Abs doesn't affect resolution. 
auto abs = Abs(a); @@ -284,10 +283,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); @@ -310,10 +309,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); @@ -334,10 +333,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. 
auto abs = Abs(a); @@ -359,10 +358,10 @@ XLA_TEST_F(ReducePrecisionInsertionTest, DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR1({1.00001}); + Literal a_literal = LiteralUtil::CreateR1({1.00001}); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); - auto a = Parameter(&builder, 0, a_literal->shape(), "a"); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); + auto a = Parameter(&builder, 0, a_literal.shape(), "a"); // These two operations should be fused by any reasonable backend. auto abs = Abs(a); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 8c62adea231d1d3197c6e483d58008b1577b156d..83997cdac21c437d460dabdbdfdb31100b1359af 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -81,9 +81,9 @@ class ReduceTest : public ClientLibraryTestBase { }, 4); // clang-format on CHECK(ShapeUtil::Equal( - literal_3d_->shape(), + literal_3d_.shape(), ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3}))) - << literal_3d_->shape().ShortDebugString(); + << literal_3d_.shape().ShortDebugString(); } // Runs an R1 => R0 reduction test with the given number of elements. 
@@ -102,10 +102,9 @@ class ReduceTest : public ClientLibraryTestBase { input_data[i] *= -1; } } - std::unique_ptr input_literal = - LiteralUtil::CreateR1(AsSlice(input_data)); + Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data)); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); float expected = 0.0; for (float item : input_data) { @@ -134,9 +133,9 @@ class ReduceTest : public ClientLibraryTestBase { Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); - std::unique_ptr input_literal = LiteralUtil::CreateR1(input_data); + Literal input_literal = LiteralUtil::CreateR1(input_data); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); bool expected = and_reduce; for (bool item : input_data) { @@ -175,12 +174,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(0, 1); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::array expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -209,12 +207,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, 
major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); float expected = 0.0; for (int64 rowno = 0; rowno < rows; ++rowno) { @@ -237,12 +234,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -295,12 +291,11 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillUnique(initial_value); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + input_literal.Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); // NativeT can be bool, and std::vector does not convert to // Span. 
@@ -352,8 +347,8 @@ class ReduceTest : public ClientLibraryTestBase { reference_reduction_function_for_uints, unsigned_int_identity); } - std::unique_ptr literal_2d_; - std::unique_ptr literal_3d_; + Literal literal_2d_; + Literal literal_3d_; uint32 seed_ = 0xdeadbeef; }; @@ -450,11 +445,10 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -482,11 +476,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); + Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 colno = 0; colno < cols; ++colno) { @@ -511,10 +504,9 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2}); Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - MakeFakeLiteral(input_shape)); + TF_ASSERT_OK_AND_ASSIGN(Literal 
input_data, MakeFakeLiteral(input_shape)); - ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4)); + ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4)); } XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { @@ -531,10 +523,9 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); - std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3D(input_data); + Literal input_literal = LiteralUtil::CreateR3FromArray3D(input_data); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); std::vector expected; for (int64 major = 0; major < 2; ++major) { @@ -595,7 +586,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { Array2D input(300, 250); input.FillRandom(214.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input); - Reduce(ConstantLiteral(&builder, *input_literal), + Reduce(ConstantLiteral(&builder, input_literal), ConstantR0(&builder, FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; input.Each( @@ -610,7 +601,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { Array2D input(150, 130); input.FillRandom(214.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input); - Reduce(ConstantLiteral(&builder, *input_literal), + Reduce(ConstantLiteral(&builder, input_literal), ConstantR0(&builder, FLT_MAX), min, {0, 1}); auto input_min = FLT_MAX; @@ -627,7 +618,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { auto initial_value = ConstantR0(&builder, std::numeric_limits::max()); - Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1}); + Reduce(ConstantLiteral(&builder, input_literal), initial_value, min, {0, 1}); ComputeAndCompareR0(&builder, 1, {}); } @@ -639,14 +630,14 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { auto initial_value = ConstantR0(&builder, std::numeric_limits::min()); 
- Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1}); + Reduce(ConstantLiteral(&builder, input_literal), initial_value, max, {0, 1}); ComputeAndCompareR0(&builder, 2, {}); } // Reduces a matrix among dimension 1. XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1}); @@ -657,7 +648,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar). XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); @@ -667,7 +658,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Tests 2D matrix ReduceToRow operation. 
XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XlaBuilder builder("reduce_among_y"); - auto m = ConstantLiteral(&builder, *literal_2d_); + auto m = ConstantLiteral(&builder, literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); @@ -677,7 +668,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {1, 2}); @@ -687,7 +678,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1}); @@ -697,7 +688,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0, 1, 2}); @@ -707,7 +698,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {0}); @@ -722,7 +713,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, 
ConstantR0(&builder, 0.0f), add, {1}); @@ -739,7 +730,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { XlaBuilder builder(TestName()); - auto m = ConstantLiteral(&builder, *literal_3d_); + auto m = ConstantLiteral(&builder, literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); Reduce(m, ConstantR0(&builder, 0.0f), add, {2}); @@ -824,12 +815,12 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); input_literal = - input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); + input_literal.Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); auto input_activations = - Parameter(&builder, 0, input_literal->shape(), "input"); + Parameter(&builder, 0, input_literal.shape(), "input"); XlaComputation add = CreateScalarAddComputation(F32, &builder); Reduce(input_activations, ConstantR0(&builder, 0.0f), add, GetParam().reduce_dims); @@ -866,21 +857,17 @@ INSTANTIATE_TEST_CASE_P( BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}}, BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}})); -// TODO(b/64093391) Disabled on GPU due to an assertion failure when running -// IrEmitterUnnested::EmitInitializer() for the Reduce operator. Failed on -// 2017-07-26. 
-XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { +XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) { XlaBuilder builder(TestName()); XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder); auto a = ConstantR0(&builder, 2.0f); auto a2 = Abs(a); - std::unique_ptr b_literal = - LiteralUtil::CreateR1({1.0f, 4.0f}); + Literal b_literal = LiteralUtil::CreateR1({1.0f, 4.0f}); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); - auto b = Parameter(&builder, 0, b_literal->shape(), "b"); + client_->TransferToServer(b_literal).ConsumeValueOrDie(); + auto b = Parameter(&builder, 0, b_literal.shape(), "b"); Reduce(b, a2, max_f32, {0}); ComputeAndCompareR0(&builder, 4.0f, {b_data.get()}); @@ -907,9 +894,9 @@ class ReduceInitializerTest : public ReduceTest { std::vector input_arr(num_elems, std::numeric_limits::lowest()); auto input_literal = LiteralUtil::CreateR1(input_arr); auto input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init, - max_fn, {0}); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn, + {0}); ComputeAndCompareR0(&builder, initializer, {input_data.get()}); } @@ -955,13 +942,12 @@ XLA_TEST_F(ReduceTest, ReduceIdentity) { float operand[] = {42.0f}; float init = 58.5f; float expected = 42.0f; - std::unique_ptr input_literal = - LiteralUtil::CreateR1(operand); + Literal input_literal = LiteralUtil::CreateR1(operand); std::unique_ptr input_global_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - std::unique_ptr input_literal2 = LiteralUtil::CreateR0(init); + client_->TransferToServer(input_literal).ConsumeValueOrDie(); + Literal input_literal2 = LiteralUtil::CreateR0(init); std::unique_ptr input_global_data2 = - client_->TransferToServer(*input_literal2).ConsumeValueOrDie(); + 
client_->TransferToServer(input_literal2).ConsumeValueOrDie(); ComputeAndCompareR0( &builder, expected, {input_global_data.get(), input_global_data2.get()}, ErrorSpec(0.0001)); diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 997880a018a264de7b0623d27997defdfc68f14a..63491a90bf2634a53591e2ab431781f3c4237681 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -73,7 +73,7 @@ class ReduceWindowTest : public ::testing::WithParamInterface, absl::Span window_dimensions, absl::Span window_strides, Padding padding) { - auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), + auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0(0.0f), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), @@ -107,9 +107,9 @@ class ReduceWindowTest : public ::testing::WithParamInterface, TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); + LiteralUtil::CreateR1({1, 1, 1, 1}), &builder_); const auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); ReduceWindow(input, init_value, CreateScalarAddComputation(FloatType(), &builder_), @@ -124,31 +124,31 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { // Regression test for b/68964348. 
TEST_P(ReduceWindowTest, R0ReduceWindow) { const auto input = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(42.0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(42.0), &builder_); const auto init = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(1.0), &builder_); + CreateConstantFromLiteral(LiteralUtil::CreateR0(1.0), &builder_); ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{}, /*window_strides=*/{}, Padding::kSame); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0(43.0), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0(43.0), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride2) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({100, 1}), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({100, 1}), {}, ErrorSpec(0.00001)); } TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1}, Padding::kSame); ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), + LiteralUtil::CreateR1({1000, 100, 10, 1, 1}), {}, ErrorSpec(0.00001)); } @@ -161,7 +161,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -176,7 +176,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) 
{ auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -190,7 +190,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -207,7 +207,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {}, DefaultErrorSpec()); } @@ -229,8 +229,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { @@ -252,8 +252,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // Tests the super windowing logic w.r.t handling prime number of windows in a @@ -277,8 +277,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, 
win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { @@ -294,8 +294,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. @@ -313,12 +313,12 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { auto lhs = Parameter(b.get(), 0, scalar, "lhs"); auto rhs = Parameter(b.get(), 1, scalar, "rhs"); Min(Add(lhs, rhs), - CreateConstantFromLiteral(*LiteralUtil::CreateR0(8.0f), b.get())); + CreateConstantFromLiteral(LiteralUtil::CreateR0(8.0f), b.get())); XlaComputation reduce_fn = b->BuildAndNoteError(); ReduceWindow( input, - CreateConstantFromLiteral(*LiteralUtil::CreateR0(0.0f), &builder_), + CreateConstantFromLiteral(LiteralUtil::CreateR0(0.0f), &builder_), reduce_fn, /*window_dimensions=*/{1, 1, 2, 1}, /*window_strides=*/{1, 1, 1, 1}, padding); @@ -332,19 +332,18 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected), {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); input_array.FillRandom(2.f, 2.f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); + Literal input_literal = 
LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); @@ -352,7 +351,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -360,9 +359,9 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - auto arg_literal = absl::make_unique(shape); - arg_literal->PopulateWithValue(1.0f); - const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + Literal arg_literal(shape); + arg_literal.PopulateWithValue(1.0f); + const auto input = CreateConstantFromLiteral(arg_literal, &builder_); Padding padding = Padding::kValid; ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); @@ -371,39 +370,38 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - auto expected = absl::make_unique(result_shape); - expected->PopulateWithValue(27.0f); - ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); + Literal expected(result_shape); + expected.PopulateWithValue(27.0f); + ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); - std::unique_ptr arg_literal = + Literal arg_literal = 
LiteralUtil::CreateFullWithDescendingLayout(input_dims, 1.0f); - const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + const auto input = CreateConstantFromLiteral(arg_literal, &builder_); Padding padding = Padding::kValid; ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); std::vector output_dims = {8, 8, 6, 6, 8, 8}; - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateFullWithDescendingLayout(output_dims, 9.0f); - ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 1; int stride = 8; @@ -413,19 +411,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = 
CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 3; int stride = 1; @@ -435,19 +432,18 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp input; auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "parameter", &builder_, &input); + 0, input_literal, "parameter", &builder_, &input); int win_len = 8; int stride = 5; @@ -457,7 +453,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {input_data.get()}, DefaultErrorSpec()); } @@ -478,18 +474,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result), - {}, DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {}, + DefaultErrorSpec()); } XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { 
std::vector input_vector(128 * 9, 1); const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); ComputeAndCompareLiteral( &builder_, - *LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + LiteralUtil::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, DefaultErrorSpec()); } @@ -504,9 +500,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -521,9 +517,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; const auto input = CreateConstantFromLiteral( - *LiteralUtil::CreateR1(input_vector), &builder_); + LiteralUtil::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {1}, Padding::kValid); - ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1({1088}), {}, + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1({1088}), {}, DefaultErrorSpec()); } @@ -540,9 +536,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateFromArray(*res), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), + {}, DefaultErrorSpec()); } TEST_P(ReduceWindowTest, 
R2ReduceWindowNonOverlappingFromBroadcast) { @@ -556,9 +551,8 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareLiteral(&builder_, - *LiteralUtil::CreateFromArray(*res), {}, - DefaultErrorSpec()); + ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), + {}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, @@ -594,7 +588,7 @@ string R4ReduceWindowTestDataToString( // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -613,12 +607,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, Array4D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], param.base_bounds[3]); - input.FillIota(1); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + input.FillRandom(0.1f, 0.1f); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(4); @@ -627,9 +620,16 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, } auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); - auto computation = param.reducer == kAdd + auto reducer = param.reducer; + if (use_bfloat16() && Product(param.window_bounds) > 128) { + // To avoid numerical issues, force the reducer to be kMax for large bf16 + // 
windows. + reducer = kMax; + } + + auto computation = reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); ReduceWindowWithGeneralPadding( @@ -640,8 +640,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*window_strides=*/param.strides, /*padding=*/padding); - CHECK(param.reducer == kAdd || param.reducer == kMax); - auto reduce_func = param.reducer == kAdd + CHECK(reducer == kAdd || reducer == kMax); + auto reduce_func = reducer == kAdd ? +[](float a, float b) { return a + b; } : +[](float a, float b) { return std::max(a, b); }; std::unique_ptr> expected = @@ -652,12 +652,11 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - std::unique_ptr expected_literal = - LiteralUtil::CreateFromArray(*expected); + Literal expected_literal = LiteralUtil::CreateFromArray(*expected); const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout( - input_literal->shape().element_type(), - AsInt64Slice(expected_literal->shape().dimensions()), param.layout); - ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()}, + input_literal.shape().element_type(), + AsInt64Slice(expected_literal.shape().dimensions()), param.layout); + ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()}, DefaultErrorSpec(), &expected_shape_with_layout); } }; @@ -809,6 +808,22 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*pad_high=*/{1, 0, 0, 0}, /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{8, 256, 256, 3}, + /*window_bounds=*/{1, 64, 64, 1}, + /*strides=*/{1, 64, 64, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 0, 2, 1}, + /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64}, + /*window_bounds=*/{112, 112, 1, 8}, + /*strides=*/{112, 112, 1, 8}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 
0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, }; INSTANTIATE_TEST_CASE_P( @@ -930,6 +945,27 @@ struct R3ReduceWindowTestData { {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2}, /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, + /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, }; string R3ReduceWindowTestDataToString( @@ -944,7 +980,7 @@ string R3ReduceWindowTestDataToString( param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_", param.reducer == kAdd ? 
"add" : "max"); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -956,35 +992,41 @@ class R3ReduceWindowTest : public ReduceWindowTestBase, R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } }; -TEST_P(R3ReduceWindowTest, Add) { +TEST_P(R3ReduceWindowTest, DoIt) { XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); - CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; Array3D input(param.base_bounds[0], param.base_bounds[1], - param.base_bounds[2], 1.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + param.base_bounds[2]); + input.FillRandom(0.1f, 0.1f); + Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); + auto reducer = param.reducer; + if (use_bfloat16()) { + input_literal = LiteralUtil::ConvertF32ToBF16(input_literal); + if (Product(param.window_bounds) > 128) { + // To avoid numerical issues, force the reducer to be kMax for large bf16 + // windows. + reducer = kMax; + } + } - XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", - &b, ¶meter); + XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input"); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); + + auto computation = reducer == kAdd + ? 
CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); + ReduceWindow(/*operand=*/parameter, /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, /*padding=*/param.padding); - auto expected = ReferenceUtil::ReduceWindow3DAdd( - /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); - - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), - {input_arg.get()}, DefaultErrorSpec()); + ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P( @@ -1079,7 +1121,7 @@ string R2ReduceWindowTestDataToString( param.layout[1], // "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -1093,16 +1135,14 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, void DoIt() { XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); - CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); - std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); + Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(2); for (int i = 0; i < 2; ++i) { @@ -1112,7 +1152,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, ? 
CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1128,7 +1168,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), + ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected), {input_arg.get()}, DefaultErrorSpec()); } }; @@ -1282,7 +1322,7 @@ string R1ReduceWindowTestDataToString( "__pad_high_", absl::StrJoin(param.pad_high, "x"), "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = absl::StrCat(str, "_bfloat16"); + absl::StrAppend(&str, "_bfloat16"); } return str; } @@ -1302,11 +1342,11 @@ TEST_P(R1ReduceWindowTest, DoIt) { const float kInitValue = 0.0f; std::vector input_vector(param.base_bounds[0]); std::iota(std::begin(input_vector), std::end(input_vector), 0); - std::unique_ptr input_literal = + Literal input_literal = LiteralUtil::CreateR1(absl::Span(input_vector)); XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", - &b, ¶meter); + auto input_arg = + CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, ¶meter); std::vector> padding(1); padding[0] = {param.pad_low[0], param.pad_high[0]}; @@ -1315,7 +1355,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { ? 
CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); auto init_value = - CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b); ReduceWindowWithGeneralPadding( /*operand=*/parameter, /*init_value=*/init_value, @@ -1334,7 +1374,7 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1(*expected), + ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1(*expected), {input_arg.get()}, DefaultErrorSpec()); } diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index d8914513819415368a628eab1f482f9644dd46b1..5cf87e565bf493167f5173588e7afa3b96282488 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -58,13 +58,13 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); // Run it. - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect 4. - LiteralTestUtil::ExpectR0Equal(4, *literal); + LiteralTestUtil::ExpectR0Equal(4, literal); } XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { @@ -91,12 +91,12 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Run it. 
std::unique_ptr x_data = - client_->TransferToServer(*LiteralUtil::CreateR0(2)) + client_->TransferToServer(LiteralUtil::CreateR0(2)) .ConsumeValueOrDie(); std::unique_ptr y_data = - client_->TransferToServer(*LiteralUtil::CreateR0(3)) + client_->TransferToServer(LiteralUtil::CreateR0(3)) .ConsumeValueOrDie(); - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{x_data.get(), y_data.get()}, @@ -104,7 +104,7 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { .ConsumeValueOrDie(); // Expect 5. - LiteralTestUtil::ExpectR0Equal(5, *literal); + LiteralTestUtil::ExpectR0Equal(5, literal); } TEST_F(ReplayTest, MapPlusTwoOverR1) { @@ -136,13 +136,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); // Run it. - std::unique_ptr literal = + Literal literal = client_ ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect result. - LiteralTestUtil::ExpectR1Equal({3, 4, 5}, *literal); + LiteralTestUtil::ExpectR1Equal({3, 4, 5}, literal); } } // namespace diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 17d12715f60f624c35169048121ca139d78a544f..dedc95b5ae8315185a35f786af42aad53bd7ad96 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -57,12 +57,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + 
ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -70,12 +70,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -83,12 +83,12 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0}); auto expected_literal = LiteralUtil::CreateR1({1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -99,29 +99,29 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { input_array.Fill(1.0f); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter", &builder, ¶meter); auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(1.0f); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + 
ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); + Literal param0_literal = LiteralUtil::CreateR0(1.0f); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", &builder, ¶meter); auto a = Neg(parameter); Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); auto expected_literal = LiteralUtil::CreateR1({-1.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -130,25 +130,25 @@ XLA_TEST_P(ReshapeTest, Trivial0x3) { Array2D input_array(0, 3); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) { XlaBuilder builder(TestName()); - std::unique_ptr param0_literal = + Literal param0_literal = LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + 
ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -157,11 +157,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x0) { Array2D input_array(3, 0); auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -170,11 +170,11 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -183,11 +183,11 @@ XLA_TEST_P(ReshapeTest, Trivial3x1) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR2({{1.0f}, {2.0f}, {3.0f}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); auto expected_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -196,12 +196,12 @@ 
XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1({}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -211,13 +211,13 @@ XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); auto expected_literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -226,12 +226,12 @@ XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 2)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -241,14 +241,14 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { auto simple = 
MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); auto input_literal = LiteralUtil::CreateFromArray(*simple); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -258,14 +258,14 @@ XLA_TEST_P(ReshapeTest, TransposeAsReshape) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 4}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -274,11 +274,11 @@ XLA_TEST_P(ReshapeTest, Transpose0x4) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 4)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected_literal = LiteralUtil::CreateR2({{}, {}, {}, {}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -288,13 
+288,13 @@ XLA_TEST_P(ReshapeTest, Transpose4x3) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Transpose(parameter, {1, 0}); auto expected = ReferenceUtil::TransposeArray2D(*a4x3); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -304,13 +304,13 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(6, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 3, 0, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 0, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -318,12 +318,12 @@ XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array4D(2, 3, 4, 0)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{24, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(24, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, 
expected_literal, {input.get()}, zero_error_spec_); } @@ -334,14 +334,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{2, 6}); auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); auto expected_literal = LiteralUtil::CreateFromArray(*expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -349,12 +349,12 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(Array2D(0, 6)); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 0}); auto expected_literal = LiteralUtil::CreateFromArray(Array2D(3, 0)); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -365,14 +365,14 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); auto input_literal = LiteralUtil::CreateFromArray(*a4x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6}); Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 
12.0f}}); auto expected_literal = LiteralUtil::CreateFromArray(expected); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -391,14 +391,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -406,7 +406,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, /*new_sizes=*/{8, 3}); @@ -418,7 +418,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { {35, 36, 37}, {40, 41, 42}, {45, 46, 47}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -426,14 +426,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + 
auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{24}); auto expected_literal = LiteralUtil::CreateR1( {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -441,7 +441,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{8, 3}); @@ -453,7 +453,7 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { {45, 16, 26}, {36, 46, 17}, {27, 37, 47}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -461,14 +461,14 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, /*new_sizes=*/{2, 6, 2}); auto expected_literal = LiteralUtil::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, 
zero_error_spec_); } @@ -494,14 +494,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { t2x2x2x3.FillWithYX(*filler2x3); auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); auto expected_literal = LiteralUtil::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -519,14 +519,14 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 1, 1) = 7; auto input_literal = LiteralUtil::CreateFromArray(t); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 4}); auto expected_literal = LiteralUtil::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -547,7 +547,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) { Reshape(parameter, dimensions, {}); auto expected_literal = LiteralUtil::CreateR0(83.0f); - ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&b, expected_literal, {input.get()}, zero_error_spec_); } } @@ -556,7 +556,7 @@ XLA_TEST_P(ReshapeTest, BadDimensions) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + 
auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); Reshape(parameter, {}, {}); EXPECT_THAT( @@ -568,7 +568,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) { XlaBuilder b(TestName()); auto input_literal = LiteralUtil::CreateR1({1.0f, 2.0f}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b, ¶meter); Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), @@ -604,7 +604,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); @@ -619,27 +619,26 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? 
BF16 : F32, {2, 8}, {1, 0}); - std::unique_ptr actual = + Literal actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); - std::unique_ptr expected = - LiteralUtil::CreateR2FromArray2D(expected_array); + Literal expected = LiteralUtil::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = LiteralUtil::ConvertF32ToBF16(*expected); + expected = LiteralUtil::ConvertF32ToBF16(expected); } - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR2({ + Literal input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); @@ -653,20 +652,20 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { {{204, 205, 206, 207}}} }); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } // Tests R2->R4 reshape with the reshape dimensions {1, 0}. 
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { XlaBuilder builder(TestName()); - std::unique_ptr input_literal = LiteralUtil::CreateR2({ + Literal input_literal = LiteralUtil::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); @@ -680,7 +679,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { {{206, 7, 107, 207}}} }); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + ComputeAndCompareLiteral(&builder, expected_literal, {input.get()}, zero_error_spec_); } @@ -691,17 +690,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { Array4D input(2, 1, 1, 1); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + Literal expected = LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -712,17 +709,15 @@ XLA_TEST_P(ReshapeTest, 
R4ToR2_2x1x4x1_To_4x2) { Array4D input(2, 1, 4, 1); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + Literal expected = LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -734,12 +729,11 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { Array4D input(5, 10, 2, 3); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); @@ -749,7 +743,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { *cell; }); auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); - ComputeAndCompareLiteral(&builder, *expected, 
{input_data.get()}, + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, zero_error_spec_); } @@ -761,12 +755,11 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { input_array.Each( [&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); XlaComputation computation = builder.Build().ConsumeValueOrDie(); @@ -775,7 +768,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5}, {2, 3, 0, 1}); - std::unique_ptr output_literal = + Literal output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, &execution_options) @@ -784,10 +777,10 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. 
if (use_bfloat16()) { - auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal); - EXPECT_EQ(expected->data(), output_literal->data()); + auto expected = LiteralUtil::ConvertF32ToBF16(input_literal); + EXPECT_EQ(expected.data(), output_literal.data()); } else { - EXPECT_EQ(input_literal->data(), output_literal->data()); + EXPECT_EQ(input_literal.data(), output_literal.data()); } } @@ -798,12 +791,12 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); + ComputeAndCompareLiteral(&builder, literal_1x2x3x4, {input.get()}); } XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { @@ -813,7 +806,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { XlaBuilder builder(TestName()); XlaOp parameter; - auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input", &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, /*new_sizes=*/{2, 4, 3, 1}); @@ -830,7 +823,7 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { {{16}, {20}, {24}}}}); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()}); + ComputeAndCompareLiteral(&builder, expected_2x4x3x1, {input.get()}); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { @@ -841,24 +834,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 
0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { @@ -869,24 +861,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - 
std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { @@ -897,24 +888,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { @@ -926,24 +916,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({3, 2, 1, 0})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) - ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal) + .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { @@ -954,24 +943,23 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { Array4D input(bounds[0], bounds[1], bounds[2], bounds[3]); input.Each([&rng, &distribution](absl::Span /* indices */, float* cell) { *cell = distribution(rng); }); - std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( - input, LayoutUtil::MakeLayout({0, 1, 2, 3})); + Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout({0, 1, 2, 3})); XlaBuilder builder(TestName()); XlaOp parameter; - auto input_data = CreateParameterAndTransferLiteral( - 0, *input_literal, "input", &builder, ¶meter); + auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input", + &builder, ¶meter); Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = - LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) - ->Relayout(input_literal->shape().layout()); + Literal expected = + LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, input_literal) + .Relayout(input_literal.shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
- ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, - zero_error_spec_, &expected->shape()); + ComputeAndCompareLiteral(&builder, expected, {input_data.get()}, + zero_error_spec_, &expected.shape()); } #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 74ded82ddfae10c21fe98ec2e250b4eaecf95222..4e55b0d7ac4453d074500f3a7fda96cb5ab52c56 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -83,25 +83,25 @@ TEST_P(FloatReverseTest, Reverses) { ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims))); std::iota(input_vector.begin(), input_vector.end(), 0.0); auto r1_literal = LiteralUtil::CreateR1(input_vector); - auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie(); + auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie(); XlaBuilder builder(TestName()); - auto a = AddParam(*input_literal, &builder); + auto a = AddParam(input_literal, &builder); Rev(a, spec.reversal); - std::unique_ptr expected = input_literal->CloneToUnique(); + Literal expected = input_literal.Clone(); std::vector output_indices(spec.input_dims.size()); - expected->EachCell([&](absl::Span indices, float) { + expected.EachCell([&](absl::Span indices, float) { for (int64 i = 0; i < indices.size(); ++i) { output_indices[i] = indices[i]; } - float value = input_literal->Get(indices); + float value = input_literal.Get(indices); for (int64 dim : spec.reversal) { output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; } - expected->Set(output_indices, value); + expected.Set(output_indices, value); }); - ComputeAndCompareLiteral(&builder, *expected, {}); + ComputeAndCompareLiteral(&builder, expected, {}); } INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest, diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc 
b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index e692b8c5d5e661587bac16a2992e35f92c4c0bd9..091a5d2cacce6ac5bf986776e5ec96612d08cc75 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -38,7 +38,7 @@ namespace { class RoundTripPackedLiteralTest : public ClientLibraryTestBase { protected: // Sends the literal to the server and retrieves it back. - std::unique_ptr RoundTripToServer(const Literal& original) { + Literal RoundTripToServer(const Literal& original) { std::unique_ptr data = client_->TransferToServer(original).ConsumeValueOrDie(); return client_->Transfer(*data).ConsumeValueOrDie(); @@ -59,12 +59,12 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0, actual->Get({0})); - EXPECT_EQ(24.0, actual->Get({1})); + EXPECT_EQ(42.0, actual.Get({0})); + EXPECT_EQ(24.0, actual.Get({1})); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { @@ -87,18 +87,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = - reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) - .ConsumeValueOrDie(); + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) + .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, actual->Get({0, 0})); - EXPECT_EQ(24.0f, actual->Get({0, 1})); - EXPECT_EQ(64.0f, actual->Get({1, 0})); - EXPECT_EQ(46.0f, actual->Get({1, 1})); + EXPECT_EQ(42.0f, actual.Get({0, 0})); + EXPECT_EQ(24.0f, actual.Get({0, 
1})); + EXPECT_EQ(64.0f, actual.Get({1, 0})); + EXPECT_EQ(46.0f, actual.Get({1, 1})); - std::unique_ptr round_tripped = RoundTripToServer(*actual); - EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); + Literal round_tripped = RoundTripToServer(actual); + EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual)); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { @@ -121,18 +120,17 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { std::unique_ptr f; TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f)); PackedLiteralReader reader(f.release()); - std::unique_ptr actual = - reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) - .ConsumeValueOrDie(); + Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout) + .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, actual->Get({0, 0})); - EXPECT_EQ(24.0f, actual->Get({1, 0})); - EXPECT_EQ(64.0f, actual->Get({0, 1})); - EXPECT_EQ(46.0f, actual->Get({1, 1})); + EXPECT_EQ(42.0f, actual.Get({0, 0})); + EXPECT_EQ(24.0f, actual.Get({1, 0})); + EXPECT_EQ(64.0f, actual.Get({0, 1})); + EXPECT_EQ(46.0f, actual.Get({1, 1})); - std::unique_ptr round_tripped = RoundTripToServer(*actual); - EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); + Literal round_tripped = RoundTripToServer(actual); + EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index a8193c2eac05ba4f0df339909f3e82a28ac35253..cd5a531603b0cb6e0f48f4dcd49891cbd5428602 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -39,69 +39,67 @@ class RoundTripTransferTest : public ClientLibraryTestBase { void RoundTripTest(const Literal& original) { std::unique_ptr data = 
client_->TransferToServer(original).ConsumeValueOrDie(); - std::unique_ptr result = - client_->Transfer(*data).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralTestUtil::Equal(original, *result)); + Literal result = client_->Transfer(*data).ConsumeValueOrDie(); + EXPECT_TRUE(LiteralTestUtil::Equal(original, result)); } }; TEST_F(RoundTripTransferTest, R0S32) { - RoundTripTest(*LiteralUtil::CreateR0(42)); + RoundTripTest(LiteralUtil::CreateR0(42)); } TEST_F(RoundTripTransferTest, R0F32) { - RoundTripTest(*LiteralUtil::CreateR0(42.0)); + RoundTripTest(LiteralUtil::CreateR0(42.0)); } TEST_F(RoundTripTransferTest, R1F32_Len0) { - RoundTripTest(*LiteralUtil::CreateR1({})); + RoundTripTest(LiteralUtil::CreateR1({})); } TEST_F(RoundTripTransferTest, R1F32_Len2) { - RoundTripTest(*LiteralUtil::CreateR1({42.0, 64.0})); + RoundTripTest(LiteralUtil::CreateR1({42.0, 64.0})); } TEST_F(RoundTripTransferTest, R1F32_Len256) { std::vector values(256); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1024) { std::vector values(1024); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1025) { std::vector values(1025); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len4096) { std::vector values(4096); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(LiteralUtil::CreateR1(values)); } TEST_F(RoundTripTransferTest, R2F32_Len10x0) { - RoundTripTest( - *LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); + RoundTripTest(LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); } TEST_F(RoundTripTransferTest, R2F32_Len2x2) { - 
RoundTripTest(*LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); + RoundTripTest(LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); } TEST_F(RoundTripTransferTest, R3F32) { RoundTripTest( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); + LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); } TEST_F(RoundTripTransferTest, R4F32) { - RoundTripTest(*LiteralUtil::CreateR4({{ + RoundTripTest(LiteralUtil::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -109,36 +107,35 @@ TEST_F(RoundTripTransferTest, R4F32) { } TEST_F(RoundTripTransferTest, EmptyTuple) { - RoundTripTest(*LiteralUtil::MakeTuple({})); + RoundTripTest(LiteralUtil::MakeTuple({})); } TEST_F(RoundTripTransferTest, TupleOfR1F32) { RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({1, 2}), + LiteralUtil::CreateR1({3, 4})})); } TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1({}), + LiteralUtil::CreateR1({3, 4})})); } TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(1.0).get(), - LiteralUtil::CreateR1({2, 3}).get()})); + RoundTripTest(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(1.0), LiteralUtil::CreateR1({2, 3})})); } // Below two tests are added to identify the cost of large data transfers. 
TEST_F(RoundTripTransferTest, R2F32_Large) { - RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); + RoundTripTest(LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); } TEST_F(RoundTripTransferTest, R4F32_Large) { Array4D array4d(2, 2, 256, 256); array4d.FillWithMultiples(1.0f); - RoundTripTest(*LiteralUtil::CreateR4FromArray4D(array4d)); + RoundTripTest(LiteralUtil::CreateR4FromArray4D(array4d)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 07460a7e01a5497aa6411ddb6866dddfc70f2068..1dd937a6d0656b53f9e7e0cb25acf80f0c3d59c0 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -161,9 +161,9 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { ConvertElementType(a, F32); int64 value = 3LL << 35; - std::unique_ptr a_literal = LiteralUtil::CreateR0(value); + Literal a_literal = LiteralUtil::CreateR0(value); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); ComputeAndCompareR0(&builder, static_cast(value), {a_data.get()}); } @@ -225,20 +225,20 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XlaBuilder builder(TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR0(2.1f); - std::unique_ptr b_literal = LiteralUtil::CreateR0(5.5f); - std::unique_ptr c_literal = LiteralUtil::CreateR0(0.5f); + Literal a_literal = LiteralUtil::CreateR0(2.1f); + Literal b_literal = LiteralUtil::CreateR0(5.5f); + Literal c_literal = LiteralUtil::CreateR0(0.5f); std::unique_ptr a_data = - client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + client_->TransferToServer(a_literal).ConsumeValueOrDie(); std::unique_ptr b_data = - client_->TransferToServer(*b_literal).ConsumeValueOrDie(); + 
client_->TransferToServer(b_literal).ConsumeValueOrDie(); std::unique_ptr c_data = - client_->TransferToServer(*c_literal).ConsumeValueOrDie(); + client_->TransferToServer(c_literal).ConsumeValueOrDie(); - XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a"); - XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b"); - XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c"); + XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a"); + XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b"); + XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c"); Mul(Mul(a, b), c); ComputeAndCompareR0(&builder, 5.775f, @@ -377,9 +377,9 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { auto dividend_literal = LiteralUtil::CreateR0(dividend); auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, - client_->TransferToServer(*dividend_literal)); + client_->TransferToServer(dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + client_->TransferToServer(divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(div_computation, @@ -388,7 +388,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { .ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(dividend / divisor); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal)); } } } @@ -419,9 +419,9 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { auto dividend_literal = LiteralUtil::CreateR0(dividend); auto divisor_literal = LiteralUtil::CreateR0(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, - client_->TransferToServer(*dividend_literal)); + client_->TransferToServer(dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, - client_->TransferToServer(*divisor_literal)); + client_->TransferToServer(divisor_literal)); auto actual_literal = client_ ->ExecuteAndTransfer(rem_computation, @@ -430,7 +430,7 
@@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { .ConsumeValueOrDie(); auto expected_literal = LiteralUtil::CreateR0(dividend % divisor); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal)); } } } @@ -441,8 +441,8 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); Rem(x, ConstantR0(&builder, 80000)); - std::unique_ptr literal = LiteralUtil::CreateR0(87919); - TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); + Literal literal = LiteralUtil::CreateR0(87919); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); ComputeAndCompareR0(&builder, 7919, {input_data.get()}); } diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc index 1858dcea61241a2aeee11592a9b09f200763b25a..d20dba028a586fa7c93c96dca03c77e3668fa644 100644 --- a/tensorflow/compiler/xla/tests/scatter_test.cc +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -62,13 +62,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) { @@ -92,13 +90,12 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); 
- std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 30}, {40, 60}, {70, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) { @@ -123,13 +120,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) { @@ -154,13 +149,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) { @@ -185,13 +178,12 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = LiteralUtil::CreateR2( + Literal operand = LiteralUtil::CreateR2( {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({2, 1}); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR1({2, 1}); + Literal updates = 
LiteralUtil::CreateR2({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) { @@ -216,13 +208,11 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) { @@ -247,13 +237,12 @@ ENTRY main { index_vector_dim=2 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 2}, {2, 1}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal scatter_indices = LiteralUtil::CreateR2({{0, 2}, {2, 1}}); + Literal updates = LiteralUtil::CreateR3( {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterNd) { @@ -277,15 +266,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), 
updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-40, 40}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) { @@ -309,15 +296,13 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // {{-4, 4}, {-5, 5}, {-6, 6}}, // {{-7, 7}, {-8, 8}, {-9, 9}}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{0, 0}, {1, 0}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR2({{0, 0}, {1, 0}}); + Literal updates = LiteralUtil::CreateR2({{-10, 10}, {-20, 20}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, DynamicUpdateSlice) { @@ -341,12 +326,11 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({1, 1}); - std::unique_ptr updates = LiteralUtil::CreateR2({{10}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = LiteralUtil::CreateR1({1, 1}); + Literal updates = LiteralUtil::CreateR2({{10}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) { @@ -370,13 +354,11 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR2({{2, 1}, {1, 1}}); - std::unique_ptr updates = - LiteralUtil::CreateR3({{{10}}, {{20}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal scatter_indices = 
LiteralUtil::CreateR2({{2, 1}, {1, 1}}); + Literal updates = LiteralUtil::CreateR3({{{10}}, {{20}}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, ZeroDimBounds) { @@ -400,11 +382,10 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = LiteralUtil::CreateR2({{}, {}, {}}); - std::unique_ptr scatter_indices = - LiteralUtil::CreateR1({0, 2}); - std::unique_ptr updates = LiteralUtil::CreateR2({{}, {}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR2({{}, {}, {}}); + Literal scatter_indices = LiteralUtil::CreateR1({0, 2}); + Literal updates = LiteralUtil::CreateR2({{}, {}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, NoUpdateWindowDims) { @@ -429,12 +410,11 @@ ENTRY main { index_vector_dim=2 } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({0, 1, 2}); - std::unique_ptr scatter_indices = + Literal operand = LiteralUtil::CreateR1({0, 1, 2}); + Literal scatter_indices = LiteralUtil::CreateR3({{{0}, {1}}, {{2}, {1}}}); - std::unique_ptr updates = - LiteralUtil::CreateR2({{10, 20}, {30, 40}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal updates = LiteralUtil::CreateR2({{10, 20}, {30, 40}}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OutOfBoundsIndex) { @@ -458,13 +438,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + Literal scatter_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal updates = LiteralUtil::CreateR3( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, 
&operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) { @@ -488,13 +468,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + Literal scatter_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal updates = LiteralUtil::CreateR3( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, NegativeIndex) { @@ -518,13 +498,13 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = + Literal operand = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR2( + Literal scatter_indices = LiteralUtil::CreateR2( {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); - std::unique_ptr updates = LiteralUtil::CreateR3( + Literal updates = LiteralUtil::CreateR3( {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, OneScalarIndex) { @@ -548,12 +528,12 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = LiteralUtil::CreateR3( + Literal operand = LiteralUtil::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); - std::unique_ptr updates = + Literal scatter_indices = LiteralUtil::CreateR0(1); + Literal updates = LiteralUtil::CreateR3({{{10, 20}, {30, 40}, {50, 60}}}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } 
XLA_TEST_F(ScatterTest, ScalarUpdate) { @@ -577,10 +557,10 @@ ENTRY main { index_vector_dim=0 } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3, 4}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR0(1); - std::unique_ptr updates = LiteralUtil::CreateR0(25); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR1({1, 2, 3, 4}); + Literal scatter_indices = LiteralUtil::CreateR0(1); + Literal updates = LiteralUtil::CreateR0(25); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } XLA_TEST_F(ScatterTest, EmptyIndices) { @@ -604,10 +584,10 @@ ENTRY main { index_vector_dim=1 } )"; - std::unique_ptr operand = LiteralUtil::CreateR1({1, 2, 3}); - std::unique_ptr scatter_indices = LiteralUtil::CreateR1({}); - std::unique_ptr updates = LiteralUtil::CreateR1({}); - RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); + Literal operand = LiteralUtil::CreateR1({1, 2, 3}); + Literal scatter_indices = LiteralUtil::CreateR1({}); + Literal updates = LiteralUtil::CreateR1({}); + RunTest(hlo_text, &operand, &scatter_indices, &updates); } } // namespace diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index c9a58aefb4acc066c10e98aea46375523cf554d0..a40c2d7de6eceea489004f5266d060f26da5d1a8 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -176,8 +176,8 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) { XlaBuilder builder(TestName()); auto original = ConstantR4FromArray4D(&builder, values); Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1}); - ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001), - &expected_literal->shape()); + ComputeAndCompareLiteral(&builder, expected_literal, {}, ErrorSpec(0.000001), + &expected_literal.shape()); } struct R1Spec { @@ -201,7 +201,7 @@ class SliceR1Test : public ClientLibraryTestBase, auto 
literal = LiteralUtil::CreateR1(input); XlaBuilder builder(TestName()); - auto original = Parameter(&builder, 0, literal->shape(), "p0"); + auto original = Parameter(&builder, 0, literal.shape(), "p0"); Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); @@ -213,7 +213,7 @@ class SliceR1Test : public ClientLibraryTestBase, } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); ComputeAndCompareR1(&builder, expected, {arg.get()}); } }; @@ -376,11 +376,11 @@ XLA_TEST_P(SliceR2Test, DoIt) { input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); - auto a = Parameter(&builder, 0, literal->shape(), "p0"); + auto a = Parameter(&builder, 0, literal.shape(), "p0"); Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); std::unique_ptr> expected = ReferenceUtil::Slice2D( input, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR2(&builder, *expected, {arg.get()}); @@ -467,9 +467,9 @@ class SliceR4Test : public ClientLibraryTestBase, XlaBuilder builder(TestName()); auto literal = LiteralUtil::CreateR4FromArray4DWithLayout( values, LayoutUtil::MakeLayout(spec.input_layout)); - auto parameter = Parameter(&builder, 0, literal->shape(), "p0"); + auto parameter = Parameter(&builder, 0, literal.shape(), "p0"); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, - client_->TransferToServer(*literal)); + client_->TransferToServer(literal)); Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001)); } diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index c20a7c8fe49cd6b9161251488b85e08459f68865..5155f0c652c7c6dbba60c421159494fa28072090 100644 --- 
a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -116,13 +116,14 @@ void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine, // array. This is uniqueness is best-effort only. Some types (half and bfloat16) // are not supported and uniqueness cannot be guaranteed if the number of // elements exceeds the number of different values supported by the type. -StatusOr> MakeFakeLiteralInternal( - const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) { +StatusOr MakeFakeLiteralInternal(const Shape& shape, + std::minstd_rand0* engine, + bool no_duplicates) { if (ShapeUtil::IsTuple(shape)) { - std::vector> elements; + std::vector elements; for (const Shape& element_shape : shape.tuple_shapes()) { TF_ASSIGN_OR_RETURN( - std::unique_ptr element, + Literal element, MakeFakeLiteralInternal(element_shape, engine, no_duplicates)); elements.push_back(std::move(element)); } @@ -131,60 +132,52 @@ StatusOr> MakeFakeLiteralInternal( if (engine == nullptr) { return Literal::CreateFromShape(shape); } - auto literal = absl::make_unique(shape); + Literal literal(shape); switch (shape.element_type()) { case BF16: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case F16: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case F32: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case F64: - PopulateWithRandomFloatingPointData(literal.get(), engine, + PopulateWithRandomFloatingPointData(&literal, engine, no_duplicates); break; case S8: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U8: - PopulateWithRandomIntegralData(literal.get(), 
engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S16: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U16: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S32: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U32: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case S64: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case U64: - PopulateWithRandomIntegralData(literal.get(), engine, - no_duplicates); + PopulateWithRandomIntegralData(&literal, engine, no_duplicates); break; case PRED: { std::uniform_int_distribution generator(0, 1); TF_CHECK_OK( - literal->Populate([&](absl::Span /*indices*/) { + literal.Populate([&](absl::Span /*indices*/) { return generator(*engine); })); break; @@ -236,8 +229,8 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -std::unique_ptr MakeRandomIndex(absl::Span index_space, - std::minstd_rand0* engine) { +Literal MakeRandomIndex(absl::Span index_space, + std::minstd_rand0* engine) { std::vector start_indices(index_space.size()); if (engine != nullptr) { for (int i = 0; i < index_space.size(); ++i) { @@ -293,7 +286,7 @@ std::vector FindConstrainedUses( // no constrained uses in the dataflow graph. 
If such constraints exist, // generate a constrained literal (either bounded in the case of indices, or // zero in the case of init_values for reductions). -StatusOr> CreateLiteralForConstrainedUses( +StatusOr CreateLiteralForConstrainedUses( const absl::Span constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { std::vector index_space; @@ -358,9 +351,9 @@ StatusOr> CreateLiteralForConstrainedUses( } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: - return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::Zero(param.shape().element_type()); case ConstantType::kOne: - return LiteralUtil::One(param.shape().element_type()).CloneToUnique(); + return LiteralUtil::One(param.shape().element_type()); case ConstantType::kUnknown: // We want the identity element for the computation, but we don't really // know what it is - so any value we generate will be just as wrong. @@ -374,34 +367,33 @@ StatusOr> CreateLiteralForConstrainedUses( // Given a module entry parameter, use the dataflow analysis to see if a // special case literal must be created, or if we can generate fake data. -StatusOr> MakeConstrainedArgument( - const HloDataflowAnalysis& dataflow, const HloInstruction& param, - std::minstd_rand0* engine) { +StatusOr MakeConstrainedArgument(const HloDataflowAnalysis& dataflow, + const HloInstruction& param, + std::minstd_rand0* engine) { const auto constrained_uses = FindConstrainedUses(dataflow, param); return CreateLiteralForConstrainedUses(constrained_uses, param, engine); } } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape, - bool pseudo_random) { +StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random) { auto engine = pseudo_random ? 
absl::make_unique() : nullptr; return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false); } -StatusOr>> MakeFakeArguments( - HloModule* const module, bool pseudo_random) { +StatusOr> MakeFakeArguments(HloModule* const module, + bool pseudo_random) { auto engine = pseudo_random ? absl::make_unique() : nullptr; return MakeFakeArguments(module, engine.get()); } -StatusOr>> MakeFakeArguments( - HloModule* const module, std::minstd_rand0* engine) { +StatusOr> MakeFakeArguments(HloModule* const module, + std::minstd_rand0* engine) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - std::vector> arguments(params.size()); + std::vector arguments(params.size()); for (int i = 0; i < params.size(); ++i) { arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie(); @@ -417,4 +409,18 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive, .status(); } +std::unique_ptr CreateCanonicalDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs) { + CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); + CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfig::DEFAULT); + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + return absl::make_unique( + shape, lhs, rhs, dot_dimension_numbers, precision_config); +} } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 7790737c093ad8e5a15c017e3f7890b6f25cb6f8..b3c8a739058475a4e51bae6ad2a98152a6532b47 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -24,10 +24,10 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/stream_executor/platform.h" namespace xla { @@ -57,8 +57,8 @@ class PseudorandomGenerator { // Generates fake data in a literal of the given shape, or returns an error // status if the element type is currently unhandled for fake data // generation. See below for documentation of pseudo_random. -StatusOr> MakeFakeLiteral(const Shape& shape, - bool pseudo_random = true); +StatusOr MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. @@ -84,20 +84,26 @@ StatusOr> MakeFakeLiteral(const Shape& shape, // TODO(b/79942829): Make interesting argument generation fast enough that using // pseudo_random does not save any noticeable amount of time so that the // parameter can be removed. -StatusOr>> MakeFakeArguments( - HloModule* const module, bool pseudo_random = true); +StatusOr> MakeFakeArguments(HloModule* const module, + bool pseudo_random = true); // Overload which accepts a random number generator. This enables generation of // different random values with sequential calls to MakeFakeArguments by reusing // the same generator. -StatusOr>> MakeFakeArguments( - HloModule* const module, std::minstd_rand0* engine); +StatusOr> MakeFakeArguments(HloModule* const module, + std::minstd_rand0* engine); // Check that a given module satisfies various constraints before trying to // execute it. 
Status VerifyHloModule(HloModule* const module, bool layout_sensitive, bool allow_mixed_precision); +// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 of +// the LHS with dimension 0 of the RHS with no batch dimensions. +// Both LHS and the RHS must be of rank 2. +std::unique_ptr CreateCanonicalDot(const Shape& shape, + HloInstruction* lhs, + HloInstruction* rhs); } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index 322c8ef090cf867f65cada5cb1dbae188f83bad6..181e5cbe290b0df0cf605cc4ef4b8a4945b3d367 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -85,10 +85,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} })") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 3); - const Literal& index_arg = *args[0]; + const Literal& index_arg = args[0]; EXPECT_EQ(index_arg.Get({0}), 0); @@ -114,10 +114,10 @@ XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) })") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 5); - const Literal& index_arg = *args[0]; + const Literal& index_arg = args[0]; EXPECT_EQ(index_arg.Get({0}), 0); @@ -140,10 +140,10 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> ( } )") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); 
ASSERT_EQ(args.size(), 2); - const Literal& key_arg = *args[0]; + const Literal& key_arg = args[0]; tensorflow::gtl::FlatSet key_set; for (const float& value : key_arg.data()) { @@ -163,10 +163,10 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> ( } )") .ValueOrDie(); - TF_ASSERT_OK_AND_ASSIGN(std::vector> args, + TF_ASSERT_OK_AND_ASSIGN(std::vector args, MakeFakeArguments(module.get())); ASSERT_EQ(args.size(), 2); - const Literal& key_arg = *args[0]; + const Literal& key_arg = args[0]; tensorflow::gtl::FlatSet key_set; for (const int32& value : key_arg.data()) { diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index c7eb9e2dbe0e27b7933f5861280a3401cd268c08..b34fd0f2e873214c509533f29553af914ddc984d 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -34,9 +34,8 @@ XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, TokenTree) { @@ -50,9 +49,8 @@ XLA_TEST_F(TokenHloTest, TokenTree) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {})); - EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken())); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {})); + EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken())); } XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { @@ -193,9 +191,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { std::unique_ptr module, 
HloRunner::CreateModuleFromString(module_string, debug_options)); auto arg = LiteralUtil::CreateR0(true); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {arg.get()})); - EXPECT_EQ(42, result->Get({})); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg})); + EXPECT_EQ(42, result.Get({})); } { @@ -204,9 +201,8 @@ ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { std::unique_ptr module, HloRunner::CreateModuleFromString(module_string, debug_options)); auto arg = LiteralUtil::CreateR0(false); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - Execute(std::move(module), {arg.get()})); - EXPECT_EQ(7, result->Get({})); + TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg})); + EXPECT_EQ(7, result.Get({})); } } diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 125513ddfd16cb4e742e7d589e22b721307621ee..d6641d257a75945be94d299a1bd4b0366e3759b7 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -69,90 +69,90 @@ class TransferManagerTest : public LocalClientTestBase { }; XLA_TEST_F(TransferManagerTest, TransferR0U32) { - std::unique_ptr literal = LiteralUtil::CreateR0(42); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR0(42); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - LiteralTestUtil::ExpectR0Equal(42, *result); + LiteralTestUtil::ExpectR0Equal(42, result); } XLA_TEST_F(TransferManagerTest, TransferR1F32) { - std::unique_ptr literal = + Literal literal = LiteralUtil::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, - *result); + result); } XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { std::vector test_vector(1024 * 1024); std::iota(test_vector.begin(), test_vector.end(), 0); - std::unique_ptr literal = LiteralUtil::CreateR1(test_vector); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR1(test_vector); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - LiteralTestUtil::ExpectR1Equal(test_vector, *result); + LiteralTestUtil::ExpectR1Equal(test_vector, result); } XLA_TEST_F(TransferManagerTest, TransferR1U8) { const char* test_string = "0123456789abcdef"; - std::unique_ptr literal = LiteralUtil::CreateR1U8(test_string); - const Shape& shape = literal->shape(); + Literal literal = LiteralUtil::CreateR1U8(test_string); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_EQ(result->GetR1U8AsString(), test_string); + EXPECT_EQ(result.GetR1U8AsString(), test_string); } XLA_TEST_F(TransferManagerTest, TransferR2F32) { - std::unique_ptr literal = + Literal literal = LiteralUtil::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - const Shape& shape = literal->shape(); + const Shape& shape = literal.shape(); auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result); } XLA_TEST_F(TransferManagerTest, TransferR2F32AndChangeLayoutTransferringToDevice) { - std::unique_ptr literal = LiteralUtil::CreateR2WithLayout( + Literal literal = LiteralUtil::CreateR2WithLayout( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1})); const Shape ondevice_shape = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -160,101 +160,99 @@ XLA_TEST_F(TransferManagerTest, // Round trip literal through device. Set the on-device layout to something // different than the literal layout. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_FALSE( - LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); + LayoutUtil::Equal(result.shape().layout(), literal.shape().layout())); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result); } XLA_TEST_F(TransferManagerTest, TransferTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleFromSlices( + 
{LiteralUtil::CreateR0(123.0f), + LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple({}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTuple({}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, 
{4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-10.0f, 123.0f})}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { - std::unique_ptr literal = LiteralUtil::CreateR1( + Literal literal = LiteralUtil::CreateR1( {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. 
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { - std::unique_ptr literal = LiteralUtil::MakeTuple( + Literal literal = LiteralUtil::MakeTupleFromSlices( {LiteralUtil::CreateR1( - {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}) - .get(), - LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}).get(), - LiteralUtil::CreateR0(complex64(0.3f, -0.4f)).get()}); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}), + LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6}), + LiteralUtil::CreateR0(complex64(0.3f, -0.4f))}); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { @@ -264,54 +262,52 @@ XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) { // supported. 
auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape()); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); - EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result)); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), result)); } XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { const int64 kIterationCount = 5000; - std::unique_ptr literal1 = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(123.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-10.0f, 123.0f}).get()}); - std::unique_ptr literal2 = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(456.0f).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), - LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) - .get(), - LiteralUtil::CreateR1({-98.0f, 153.0f}).get()}); - - auto device_buffer1 = AllocateDeviceBuffer(literal1->shape()); - auto device_buffer2 = AllocateDeviceBuffer(literal2->shape()); + Literal literal1 = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(123.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}), + LiteralUtil::CreateR1({44.0f, -10.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-10.0f, 123.0f})}); + Literal literal2 = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(456.0f), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}), + LiteralUtil::CreateR1({44.0f, -11.0f, 3333333.3f})}), + LiteralUtil::CreateR1({-98.0f, 153.0f})}); + + auto device_buffer1 = AllocateDeviceBuffer(literal1.shape()); + auto device_buffer2 = AllocateDeviceBuffer(literal2.shape()); auto stream1 = stream_; auto stream2 = stream_->GetOrCreateSubStream(); - std::unique_ptr result1, 
result2; + Literal result1, result2; // Round trip literals through device in multiple streams asynchronously. for (int i = 0; i < kIterationCount; ++i) { - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, literal1, device_buffer1)); - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2, + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, literal2, device_buffer2)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr this_result1, + Literal this_result1, transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr this_result2, + Literal this_result2, transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2)); result1 = std::move(this_result1); result2 = std::move(this_result2); } - EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1)); - EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal1, result1)); + EXPECT_TRUE(LiteralTestUtil::Equal(literal2, result2)); } class TransferDeviceToHostBenchmark : public TransferManagerTest { @@ -323,20 +319,19 @@ class TransferDeviceToHostBenchmark : public TransferManagerTest { tensorflow::testing::StopTiming(); SetUp(); - std::vector> tuple_elements; + std::vector tuple_elements; for (int i = 0; i < num_tuple_elements; ++i) { tuple_elements.push_back( LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } - std::unique_ptr literal = - LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); - TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); 
tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr result, + Literal result, transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); } tensorflow::testing::StopTiming(); @@ -355,17 +350,16 @@ class TransferHostToDeviceBenchmark : public TransferManagerTest { tensorflow::testing::StopTiming(); SetUp(); - std::vector> tuple_elements; + std::vector tuple_elements; for (int i = 0; i < num_tuple_elements; ++i) { tuple_elements.push_back( LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); } - std::unique_ptr literal = - LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); - auto device_buffer = AllocateDeviceBuffer(literal->shape()); + Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal.shape()); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal, device_buffer)); } tensorflow::testing::StopTiming(); diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index f2b3b49015c7d74d786f63776abff1d5181fd961..619d2a388b5646c31f0a61f709a2ab3067e39c03 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -51,13 +51,13 @@ XLA_TEST_F(TupleTest, TupleConstant) { {1.1f, 2.2f, 3.5f}, // row 0 {4.8f, 5.0f, 6.7f}, // row 1 }; - auto value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); + auto value = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector), + LiteralUtil::CreateR2(constant_matrix)}); - ConstantLiteral(&builder, *value); - 
ComputeAndCompareTuple(&builder, *value, {}, error_spec_); + ConstantLiteral(&builder, value); + ComputeAndCompareTuple(&builder, value, {}, error_spec_); } // Tests a tuple made of scalar constants. @@ -66,12 +66,12 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) { const float constant_scalar1 = 7.3f; const float constant_scalar2 = 1.2f; - auto value = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar1).get(), - LiteralUtil::CreateR0(constant_scalar2).get()}); + auto value = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar1), + LiteralUtil::CreateR0(constant_scalar2)}); - ConstantLiteral(&builder, *value); - ComputeAndCompareTuple(&builder, *value, {}, error_spec_); + ConstantLiteral(&builder, value); + ComputeAndCompareTuple(&builder, value, {}, error_spec_); } // Tests the creation of tuple data. @@ -88,11 +88,11 @@ XLA_TEST_F(TupleTest, TupleCreate) { ConstantR1(&builder, constant_vector), ConstantR2(&builder, constant_matrix)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(constant_scalar), + LiteralUtil::CreateR1(constant_vector), + LiteralUtil::CreateR2(constant_matrix)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Tests the creation of tuple data. 
@@ -102,10 +102,9 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { Tuple(&builder, {ConstantR0(&builder, 7.0), ConstantR1(&builder, {})}); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(7.0).get(), - LiteralUtil::CreateR1({}).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(7.0), LiteralUtil::CreateR1({})}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Tests the creation of an empty tuple. @@ -113,7 +112,7 @@ XLA_TEST_F(TupleTest, EmptyTupleCreate) { XlaBuilder builder(TestName()); Tuple(&builder, {}); auto expected = LiteralUtil::MakeTuple({}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } // Trivial test for extracting a tuple element with GetTupleElement. @@ -196,10 +195,10 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ConstantR2(&builder, constant_matrix)}); Tuple(&builder, {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2(constant_matrix).get(), - LiteralUtil::CreateR1(constant_vector).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR2(constant_matrix), + LiteralUtil::CreateR1(constant_vector)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { @@ -218,11 +217,11 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true} auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false} Select(direction ? 
v1_gt : v2_gt, v1_v2, v2_v1); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(direction).get(), - LiteralUtil::CreateR0(!direction).get()}); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(direction), + LiteralUtil::CreateR0(!direction)}); - ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, + ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()}, error_spec_); } } @@ -287,10 +286,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec2), LiteralUtil::CreateR1(vec1)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, TuplesInAMap) { @@ -332,10 +330,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { ConstantR1(&builder, vec1)}); Select(ConstantR0(&builder, true), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec1).get(), - LiteralUtil::CreateR1(vec2).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec1), LiteralUtil::CreateR1(vec2)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { @@ -408,10 +405,9 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { Select(ConstantR0(&builder, false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1(vec2), 
LiteralUtil::CreateR1(vec1)}); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, NestedTuples) { @@ -423,12 +419,11 @@ XLA_TEST_F(TupleTest, NestedTuples) { auto expected_v1 = LiteralUtil::CreateR1({1.0, 2.0}); auto expected_s = LiteralUtil::CreateR0(42.0); auto expected_inner_tuple = - LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); + LiteralUtil::MakeTuple({&expected_v1, &expected_s}); auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); - auto expected = - LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2}); - ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); + ComputeAndCompareTuple(&builder, expected, {}, error_spec_); } XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { @@ -446,14 +441,12 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::MakeTuple( - { - LiteralUtil::CreateR1({1.0, 2.0, 3.0}).get(), - LiteralUtil::CreateR1({4.0, 5.0, 6.0}).get(), - }) - .get(), - LiteralUtil::CreateR1({7.0, 8.0, 9.0}).get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::MakeTupleFromSlices({ + LiteralUtil::CreateR1({1.0, 2.0, 3.0}), + LiteralUtil::CreateR1({4.0, 5.0, 6.0}), + }), + LiteralUtil::CreateR1({7.0, 8.0, 9.0}), })) .ConsumeValueOrDie(); @@ -484,40 +477,36 @@ XLA_TEST_F(TupleTest, ComplexTuples) { std::unique_ptr arg0 = client_ - ->TransferToServer(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0({1, 2}).get(), - LiteralUtil::MakeTuple( - {LiteralUtil::CreateR1({{10, 20}, {30, 40}}) - .get(), + ->TransferToServer(LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0({1, 2}), + LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR1({{10, 20}, {30, 40}}), LiteralUtil::CreateR2( {{{100, 200}, {300, 400}}, {{1000, 2000}, {3000, 4000}}, - {{10000, 20000}, {30000, 
40000}}}) - .get()}) - .get()})) + {{10000, 20000}, {30000, 40000}}})})})) .ConsumeValueOrDie(); std::unique_ptr arg1 = client_ ->TransferToServer( - *LiteralUtil::CreateR1({{1, 2}, {1, -2}})) + LiteralUtil::CreateR1({{1, 2}, {1, -2}})) .ConsumeValueOrDie(); auto sum = LiteralUtil::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = absl::make_unique(sum->shape()); - ASSERT_TRUE(prod->Populate( - [&sum](absl::Span indexes) { - return sum->Get(indexes) * - (indexes[indexes.size() - 1] == 0 - ? complex64(1, 2) - : complex64(1, -2)); - }) + Literal prod(sum.shape()); + ASSERT_TRUE(prod.Populate([&sum](absl::Span indexes) { + return sum.Get(indexes) * + (indexes[indexes.size() - 1] == 0 + ? complex64(1, 2) + : complex64(1, -2)); + }) .ok()); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(), - LiteralUtil::CreateR0({123, 456}).get()}); - ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, + auto expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::MakeTupleFromSlices({prod, sum}), + LiteralUtil::CreateR0({123, 456})}); + ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()}, error_spec_); } @@ -541,10 +530,10 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { .ValueOrDie(); auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({1, 2, 3})); - auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); + auto result = ExecuteNoHloPasses(std::move(module), {¶m}); EXPECT_TRUE(LiteralTestUtil::Equal( - *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), - *result)); + LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2({{1, 2, 3}})), + result)); } // Disabled on interpreter due to lack of outfeed. 
@@ -581,16 +570,15 @@ XLA_TEST_F(TupleHloTest, tensorflow::Env::Default()->StartThread( tensorflow::ThreadOptions(), "execute_thread", [&] { TF_EXPECT_OK(Execute(std::move(module), - {param0.get(), param1.get(), param1.get(), - param0.get(), param4.get()}) + {¶m0, ¶m1, ¶m1, ¶m0, ¶m4}) .status()); })); auto expected = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({2, 3})); - auto literal = Literal::CreateFromShape(expected->shape()); + auto literal = Literal::CreateFromShape(expected.shape()); TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( - backend().default_stream_executor(), expected->shape(), *literal)); - EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal)); + backend().default_stream_executor(), expected.shape(), literal)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 8f80a9f3e466d73f2b718452d9a0d64a80c3b36f..4fbd7f2fb174ac899c1e3b23801986cb52db96a2 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -100,9 +100,9 @@ void UnaryOpTest::AbsTestHelper() { {-inf(), 0}}); Abs(arg); - std::unique_ptr expected = + Literal expected = LiteralUtil::CreateR1({2, 25, 0, 0.5, inf(), inf()}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } template <> @@ -113,9 +113,9 @@ void UnaryOpTest::SignTestHelper() { {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); Sign(arg); - std::unique_ptr expected = LiteralUtil::CreateR1( + Literal expected = LiteralUtil::CreateR1( {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } template <> @@ -127,9 +127,8 @@ void 
UnaryOpTest::SignAbsTestHelper() { auto abs = Abs(arg); Sub(Mul(sign, ConvertElementType(abs, C64)), arg); - std::unique_ptr expected = - LiteralUtil::CreateR1({0, 0, 0, 0}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + Literal expected = LiteralUtil::CreateR1({0, 0, 0, 0}); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) { @@ -172,9 +171,8 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { Add(sgnc, ConvertElementType( Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64)); - std::unique_ptr expected = - LiteralUtil::CreateR0({-2.6f, 0.8f}); - ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); + Literal expected = LiteralUtil::CreateR0({-2.6f, 0.8f}); + ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, SignTestR1) { diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 1bdf1867b9330b715b0ba4aca71d56307883c775..7abd8651d5ca272f9e82d797870a3bd6b1589615 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -348,9 +348,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { // have all reached 2.0. 
auto expected_data = LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}); - auto expected = LiteralUtil::MakeTuple({expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { @@ -401,11 +401,10 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { auto expected_w1 = LiteralUtil::CreateR1({1.0f, 1.0f, 1.0f}); auto expected_w2 = LiteralUtil::CreateR1({2.0f, 2.0f, 2.0f}); auto expected_w3 = LiteralUtil::CreateR1({3.0f, 3.0f, 3.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(), - expected_w3.get(), expected_w1.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple( + {&expected_counter, &expected_w2, &expected_w3, &expected_w1}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { @@ -510,10 +509,9 @@ TEST_F(WhileTest, WhileWithTupleResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR1( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << 
ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(WhileTest, WhileWithPredicateTupleResult) { @@ -557,9 +555,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_predicate = LiteralUtil::CreateR0(true); - auto expected = LiteralUtil::MakeTuple( - {expected_counter.get(), expected_predicate.get()}); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); + auto expected = + LiteralUtil::MakeTuple({&expected_counter, &expected_predicate}); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0)); } TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { @@ -602,10 +600,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR0(7); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } // Tests two while nodes when the result type T is a Tuple and the second @@ -886,10 +883,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { auto expected_counter = LiteralUtil::CreateR0(5); auto expected_data = LiteralUtil::CreateR1( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); - auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); - VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); - ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); + auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data}); + VLOG(2) << "expected = " << 
ShapeUtil::HumanString(expected.shape()); + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001)); } // Tests a while node when the result type T is a vector of S32. @@ -977,11 +973,11 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto expected_element = LiteralUtil::CreateR1({1, 1}); auto expected = - LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()}); + LiteralUtil::MakeTuple({&expected_element, &expected_element}); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); - ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, + client_->TransferToServer(LiteralUtil::CreateR1({42, 42}))); + ComputeAndCompareTuple(&outer, expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1005,7 +1001,7 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR1({42, 42}))); + client_->TransferToServer(LiteralUtil::CreateR1({42, 42}))); ComputeAndCompareR1(&outer, {1.0f, 1.0f}, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1031,7 +1027,7 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR0(42))); + client_->TransferToServer(LiteralUtil::CreateR0(42))); ComputeAndCompareR0(&outer, 43.0f, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1070,12 +1066,12 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr parameter_data, - client_->TransferToServer(*LiteralUtil::CreateR0(1))); + client_->TransferToServer(LiteralUtil::CreateR0(1))); auto add1 = LiteralUtil::CreateR0(15); auto add2 = LiteralUtil::CreateR0(16); - auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()}); - ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, + auto expected = 
LiteralUtil::MakeTuple({&add1, &add2}); + ComputeAndCompareTuple(&outer, expected, {parameter_data.get()}, ErrorSpec(1e-6)); } @@ -1228,7 +1224,7 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) { GetTupleElement(while_instruction, 3); TF_ASSERT_OK_AND_ASSIGN( - auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2( + auto param_value, client_->TransferToServer(LiteralUtil::CreateR2( {{1.0, 2.0}, {-1.0, -2.0}}))); ComputeAndCompareR2( @@ -1258,9 +1254,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) { XlaBuilder builder(TestName()); While(condition, body, ConstantR0(&builder, 0)); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(true))); - TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0(false))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(true))); + TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0(false))); ComputeAndCompareR0(&builder, 2, {}); } diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 7fd42944debe38abbf6f0ca36bc5c7ecb1aeaf97..db5a824de08edeb81b5deb047507dc6158833008 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -144,14 +144,14 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, transfer_manager->AllocateScopedShapedBuffer( lhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); + stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer rhs_arg, transfer_manager->AllocateScopedShapedBuffer( rhs_arg_shape, allocator, 
backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); + stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 442e66321ee732f3d9cdfe4931433bd864b7fa82..cdde88c1359416d423685f330e9cbdf77948034f 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -39,8 +39,7 @@ limitations under the License. namespace xla { -StatusOr> TextLiteralReader::ReadPath( - absl::string_view path) { +StatusOr TextLiteralReader::ReadPath(absl::string_view path) { CHECK(!absl::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; @@ -57,7 +56,7 @@ StatusOr> TextLiteralReader::ReadPath( TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file) : file_(file) {} -StatusOr> TextLiteralReader::ReadAllLines() { +StatusOr TextLiteralReader::ReadAllLines() { tensorflow::io::RandomAccessInputStream stream(file_.get()); tensorflow::io::BufferedInputStream buf(&stream, 65536); string shape_string; @@ -74,9 +73,9 @@ StatusOr> TextLiteralReader::ReadAllLines() { ShapeUtil::HumanString(shape)); } - auto result = absl::make_unique(shape); + Literal result(shape); const float fill = std::numeric_limits::quiet_NaN(); - result->PopulateWithValue(fill); + result.PopulateWithValue(fill); std::vector pieces; std::vector coordinates; std::vector coordinate_values; @@ -116,7 +115,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { "\"%s\"", shape.dimensions_size(), coordinate_values.size(), line); } - result->Set(coordinate_values, value); + result.Set(coordinate_values, value); } return std::move(result); } diff --git a/tensorflow/compiler/xla/text_literal_reader.h 
b/tensorflow/compiler/xla/text_literal_reader.h index b265640802c88847ce57e9f942f9f0859b873ae8..c40b43279f56fbd6e8ec91cc45c1f8e4cac8b5ef 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -41,7 +41,7 @@ class TextLiteralReader { public: // See class comment -- reads a file in its entirety (there must be only one // literal in the text file path provided). - static StatusOr> ReadPath(absl::string_view path); + static StatusOr ReadPath(absl::string_view path); private: // Ownership of file is transferred. @@ -49,7 +49,7 @@ class TextLiteralReader { // Parses a shape string on the first line, followed by lines of values to the // end of the file. - StatusOr> ReadAllLines(); + StatusOr ReadAllLines(); // Owns the file being read std::unique_ptr file_; diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index 92f9b4f9f0efa2dc08287bdcbefc88f879164308..1fab4e3a08dd3d76a6efeaabe7bf8ab96892e638 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -42,16 +42,15 @@ TEST(TextLiteralReaderTest, ReadsR3File) { tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, contents) .ok()); - std::unique_ptr literal = - TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); + Literal literal = TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); EXPECT_TRUE( - ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape())); - EXPECT_EQ(42.5, literal->Get({0, 0, 0})); - EXPECT_EQ(43.5, literal->Get({0, 0, 1})); - EXPECT_EQ(44.5, literal->Get({0, 0, 2})); - EXPECT_EQ(45.5, literal->Get({0, 1, 0})); - EXPECT_EQ(46.5, literal->Get({0, 1, 1})); - EXPECT_EQ(47.5, literal->Get({0, 1, 2})); + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal.shape())); + EXPECT_EQ(42.5, literal.Get({0, 0, 0})); + EXPECT_EQ(43.5, literal.Get({0, 0, 1})); + EXPECT_EQ(44.5, 
literal.Get({0, 0, 2})); + EXPECT_EQ(45.5, literal.Get({0, 1, 0})); + EXPECT_EQ(46.5, literal.Get({0, 1, 1})); + EXPECT_EQ(47.5, literal.Get({0, 1, 2})); } } // namespace diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 4ea02faffcd52065b05c0444202bd1a3d9d87ee6..5cbaf2fcc192c48092272094710ccaf5c9cf9616 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -37,7 +37,7 @@ TEST(TextLiteralWriterTest, WritesFloatLiteral) { }); string path = tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever"); - ASSERT_IS_OK(TextLiteralWriter::WriteToPath(*literal, path)); + ASSERT_IS_OK(TextLiteralWriter::WriteToPath(literal, path)); string contents; TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path, &contents)); diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc index 23ce1d235b9f2613505f8a3bfbd1a4c1162debd4..0c3ec5934e546f551089f715dbbe6f4479e56c3c 100644 --- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc +++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc @@ -67,8 +67,8 @@ int main(int argc, char** argv) { floats.push_back(value); } - absl::string_view content(absl::bit_cast(floats.data()), - floats.size() * sizeof(float)); + tensorflow::StringPiece content(absl::bit_cast(floats.data()), + floats.size() * sizeof(float)); TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output_file, content)); return 0; diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index ba814af4769f43dbe96190c902cf6f52ca5659bb..0c41f227b31ebe1f01073785ea2a666093aefdb3 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -121,11 +121,10 @@ StatusOr 
ReplayComputation(const HloSnapshot& module, } } else { // use recorded data if available for (const auto& proto : module.arguments()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - Literal::CreateFromProto(proto)); + TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto)); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer data, - client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0)); + client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0)); scoped_shaped_buffer_arguments.push_back(std::move(data)); } for (const auto& argument : scoped_shaped_buffer_arguments) { @@ -161,12 +160,12 @@ StatusOr ReplayComputation(const HloSnapshot& module, // --generate_fake_infeed is passed and there exists an infeed operation in // the HloSnapshot. absl::optional pool; - std::unique_ptr data; + Literal data; if (provide_infeed) { data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); } auto transfer_infeed = [&data, client]() { - TF_CHECK_OK(client->TransferToInfeed(*data)); + TF_CHECK_OK(client->TransferToInfeed(data)); }; if (provide_infeed) { pool.emplace(tensorflow::Env::Default(), "infeed", @@ -214,9 +213,9 @@ StatusOr ReplayComputation(const HloSnapshot& module, << "s: " << module.hlo().hlo_module().name(); } - TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, + TF_ASSIGN_OR_RETURN(Literal result_literal, client->ShapedBufferToLiteral(*result)); - return std::move(*result_literal); + return result_literal; } StatusOr ParseInputFile(const string& filename, @@ -305,11 +304,11 @@ int RealMain(absl::Span args, const Options& opts) { result.ToString().c_str()); auto& snapshot = snapshots[i]; if (snapshot.has_result()) { - std::unique_ptr literal = + Literal literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(snapshot.result().shape()).c_str(), - literal->ToString().c_str()); + literal.ToString().c_str()); } } } diff --git a/tensorflow/compiler/xla/tools/show_literal.cc 
b/tensorflow/compiler/xla/tools/show_literal.cc index 51909190a3ef20c3df78d08796e88bdbb650609d..4f8852f8c11fb749ef851bc4faf176fcc5cb3524 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -40,8 +40,8 @@ int main(int argc, char **argv) { xla::LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1], &literal_proto)); - std::unique_ptr literal = + xla::Literal literal = xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); - fprintf(stderr, "%s\n", literal->ToString().c_str()); + fprintf(stderr, "%s\n", literal.ToString().c_str()); } diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc index 48c837481181f6ad8f864569fd62e0e23fa02ecd..4b5c276bdf66f3dc5364aae4654b13a625b0a4f7 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -36,16 +36,16 @@ int main(int argc, char **argv) { LOG(QFATAL) << "Usage: " << argv[0] << " "; } - std::unique_ptr literal = + xla::Literal literal = xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie(); - LOG(INFO) << "literal: " << *literal; - fprintf(stderr, "%s\n", literal->ToString().c_str()); - if (literal->shape().element_type() == xla::F32) { - float min = *std::min_element(literal->data().begin(), - literal->data().end()); - float max = *std::max_element(literal->data().begin(), - literal->data().end()); + LOG(INFO) << "literal: " << literal; + fprintf(stderr, "%s\n", literal.ToString().c_str()); + if (literal.shape().element_type() == xla::F32) { + float min = *std::min_element(literal.data().begin(), + literal.data().end()); + float max = *std::max_element(literal.data().begin(), + literal.data().end()); fprintf(stderr, "min: %a=%f\n", min, min); fprintf(stderr, "max: %a=%f\n", max, max); } diff --git 
a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 8e43f275e10408f1ed2b84b031a8316a94de3a82..73b3589dbf12341ddb3f3e819a550467a7b4d166 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -351,6 +351,7 @@ message DeviceAssignmentProto { message LiteralProto { Shape shape = 1; repeated bool preds = 2; + bytes s8s = 15; bytes u8s = 3; repeated int32 s32s = 4; repeated int64 s64s = 5; @@ -364,7 +365,7 @@ message LiteralProto { bytes f16s = 11; bytes bf16s = 13; repeated int64 sparse_indices = 14; - // Next = 15 + // Next = 16 } message WindowDimension { @@ -580,7 +581,7 @@ message SourceTarget { // Used to indicate the precision configuration. It has backend specific // meaning. -message PrecisionConfigProto { +message PrecisionConfig { enum Precision { DEFAULT = 0; HIGH = 1; diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index efbe9802784771618b46c08f24af46c8664001e7..2ff97914f862e0ec30fc54602ec5fee2a0a5ebca 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -56,6 +56,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/stream_executor", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", ], ) diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 68ba17a424cf5d204eb780e495580efe60ca863c..9e3d2454d16730c1d1f93cb384db88544380f77e 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -46,19 +46,15 @@ cc_library( deps = [ ":xrt_state_ops", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - 
"//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt:xrt_utils", "//tensorflow/core:core_cpu_internal", @@ -67,6 +63,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor:stream_executor_headers_lib", + "@com_google_absl//absl/strings", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 5cf2bc886177a3ac521b412b894628e6ec4eba42..1d4f8d97f2ed8b263878b94b365b7fb5b949b1a2 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -40,7 +41,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/proto_serialization.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/types.h" @@ -70,7 +70,7 @@ Status CompilationCacheKey(const xrt::XLAComputation& computation, string serialized; TF_RET_CHECK(SerializeToStringDeterministic(computation, &serialized)); uint64 fingerprint = Fingerprint64(serialized); - *key = strings::StrCat(fingerprint); + *key = absl::StrCat(fingerprint); return Status::OK(); } diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 478c9663a7641ba2bf22e9119212ee8ef8947d4f..54b06558adcd8ef1f8f1bee52d210d558801afea 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -49,7 +49,7 @@ class XRTStateHelpers { // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an // OpKernel::Compute method. static Status MakeLiteral(const xla::LiteralProto& proto, - std::unique_ptr* literal) { + xla::Literal* literal) { TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto)); return Status::OK(); } @@ -173,7 +173,7 @@ class XRTAllocateOp : public OpKernel { errors::InvalidArgument( "Unable to parse allocation input to XLAAllocation")); - std::unique_ptr literal; + xla::Literal literal; OP_REQUIRES_OK( ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal)); @@ -189,7 +189,7 @@ class XRTAllocateOp : public OpKernel { XRTTupleAllocation* allocation; OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( - *literal, device_ref.backend(), + literal, device_ref.backend(), device_ref.device_ordinal(), &allocation)); // Intern takes ownership of our reference to allocation. 
@@ -381,11 +381,11 @@ class XRTReadLiteralOp : public OpKernel { OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef( ctx, allocation->device_ordinal(), &device_ref)); - std::unique_ptr literal; + xla::Literal literal; OP_REQUIRES_OK( ctx, allocation->ToLiteral(device_ref.backend(), device_ref.device_ordinal(), &literal)); - xla::LiteralProto literal_proto = literal->ToProto(); + xla::LiteralProto literal_proto = literal.ToProto(); Tensor output(DT_STRING, TensorShape({})); literal_proto.SerializeToString(&output.scalar()()); diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 5b8516bf1dceb4ffa37a8fb52fb287281a661e9d..2952feb16a8a60aecf16be87c9b800d314c4af58 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -52,44 +52,44 @@ string DeviceFromFlag() { xla::LiteralProto TwoElementTuple() { auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); - auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()}); - return tuple->ToProto(); + auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); + return tuple.ToProto(); } xla::LiteralProto ScalarLiteral() { auto scalar = xla::LiteralUtil::CreateR0(12.0f); - return scalar->ToProto(); + return scalar.ToProto(); } xla::LiteralProto NestedTuple() { auto array = xla::LiteralUtil::CreateR1({1.0f, 3.0f}); auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); - auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()}); + auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); auto scalar = xla::LiteralUtil::CreateR0(12.0f); - auto nested = xla::LiteralUtil::MakeTuple({tuple.get(), scalar.get()}); - return nested->ToProto(); + auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar}); + return nested.ToProto(); } xla::LiteralProto MakeTuple0() { auto scalar = xla::LiteralUtil::CreateR0(12.0f); auto array = 
xla::LiteralUtil::CreateR1({1.0f, 3.0f}); auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}); - auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()}); - auto nested0 = xla::LiteralUtil::MakeTuple({scalar.get(), tuple.get()}); - auto nested1 = xla::LiteralUtil::MakeTuple({scalar.get(), nested0.get()}); - return nested1->ToProto(); + auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix}); + auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple}); + auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0}); + return nested1.ToProto(); } -xla::LiteralProto FloatVector(gtl::ArraySlice v) { +xla::LiteralProto FloatVector(absl::Span v) { auto array = xla::LiteralUtil::CreateR1(v); - return array->ToProto(); + return array.ToProto(); } bool CompareLiteralProtos(const xla::LiteralProto& a, const xla::LiteralProto& b) { auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie(); auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); - bool equal = *l_a == *l_b; + bool equal = l_a == l_b; if (!equal) { LOG(INFO) << "LiteralProtos don't match " << a.DebugString() << " != " << b.DebugString(); @@ -100,7 +100,7 @@ bool CompareLiteralProtos(const xla::LiteralProto& a, bool CompareLiteralToLiteralProto(const xla::Literal& a, const xla::LiteralProto& b) { auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie(); - bool equal = a == *l_b; + bool equal = a == l_b; if (!equal) { LOG(INFO) << "Literal and LiteralProto don't match " << a.ToProto().DebugString() << " != " << b.DebugString(); @@ -211,7 +211,7 @@ TEST(RawApiTest, SubBuffer) { TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs)); auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie(); - auto base_elements = base_literal->DecomposeTuple(); + auto base_elements = base_literal.DecomposeTuple(); auto nested_0_elements = base_elements[0].Clone().DecomposeTuple(); xla::LiteralProto response_0; 
EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar()())); @@ -343,7 +343,7 @@ TEST(RawApiTest, CompileAndExecute) { EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); - EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response)); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } TEST(RawApiTest, CompileAndExecuteReturnTuple) { @@ -392,8 +392,8 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto sum = xla::LiteralUtil::CreateR1({9.0f, 7.0f}); - auto expected = xla::LiteralUtil::MakeTuple({sum.get()}); - EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response)); + auto expected = xla::LiteralUtil::MakeTuple({&sum}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); } } // namespace diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 911ac9a78b7c7477f620f47d7fc79f9196a86469..d05a1e7dcbff440e0daf03bd25535c26d82b6a0b 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -32,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -174,7 +174,7 @@ XRTTupleAllocation::~XRTTupleAllocation() { } Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal, - std::unique_ptr* literal) { + xla::Literal* literal) { auto transfer_manager = backend->transfer_manager(); TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal)); TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice( @@ -201,14 +201,14 @@ const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() { /*static*/ Status XRTTupleAllocation::Lookup(ResourceMgr* rm, int64 key, XRTTupleAllocation** allocation) { - string key_string = strings::StrCat(key); + string key_string = absl::StrCat(key); TF_RETURN_IF_ERROR(rm->Lookup(kTupleContainer, key_string, allocation)); return Status::OK(); } /*static*/ Status XRTTupleAllocation::DeleteFromResourceManager(ResourceMgr* rm, int64 key) { - string key_string = strings::StrCat(key); + string key_string = absl::StrCat(key); return rm->Delete(kTupleContainer, key_string); } @@ -410,7 +410,7 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr; Status XRTTupleAllocation::Intern(ResourceMgr* rm, int64* key) { *key = get_uid(); - string key_string = strings::StrCat(*key); + string key_string = absl::StrCat(*key); return rm->Create(kTupleContainer, key_string, this); } diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h index 42705688ddfeb21aa734cccfce36c8d11d0d60a9..73b5584e38f781343fe6793af7ad28232fbfc184 100644 --- a/tensorflow/compiler/xrt/xrt_state.h +++ b/tensorflow/compiler/xrt/xrt_state.h @@ -135,7 +135,7 @@ class XRTTupleAllocation : public ResourceBase { // Copies the allocation from device to host and returns 
it in literal. Status ToLiteral(xla::Backend* backend, int device_ordinal, - std::unique_ptr* literal); + xla::Literal* literal); // True if none of the buffers in the allocation are aliased by any other live // handle. diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 66983801bf81188f81b9d4149eec5f0d20a296b4..d98a24994cbf080184fe47111a718f31b7a64f0b 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -20,13 +20,7 @@ py_library( ), srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = if_not_windows([ - # TODO(aaroey): tensorrt dependency has to appear before tflite so the - # build can resolve its flatbuffers symbols within the tensorrt library. - # This is an issue with the tensorrt static library and will be fixed by - # the next tensorrt release, so fix the order here after that. - "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows - ]) + [ + deps = [ "//tensorflow/contrib/all_reduce", "//tensorflow/contrib/batching:batch_py", "//tensorflow/contrib/bayesflow:bayesflow_py", @@ -135,6 +129,7 @@ py_library( ]) + if_not_windows([ "//tensorflow/contrib/bigtable", # depends on bigtable "//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows + "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", ]), ) @@ -171,7 +166,9 @@ cc_library( "//tensorflow/contrib/kinesis:dataset_kernels", ], "//conditions:default": [], - }), + }) + if_not_windows([ + "//tensorflow/contrib/tensorrt:trt_engine_op_kernel", + ]), ) cc_library( @@ -208,5 +205,7 @@ cc_library( "//tensorflow/contrib/kinesis:dataset_ops_op_lib", ], "//conditions:default": [], - }), + }) + if_not_windows([ + "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib", + ]), ) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 5f477a79a3d960bc2cd2df2d288ae80e30671d75..9478e42b46f363c9ad673ade1ea1ceff27075ff0 100644 --- 
a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -21,6 +21,14 @@ from __future__ import print_function import os +from tensorflow.python.tools import component_api_helper +component_api_helper.package_hook( + parent_package_str=( + "tensorflow.contrib"), + child_package_str=( + "tensorflow_estimator.contrib.estimator")) +del component_api_helper + # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import autograph from tensorflow.contrib import batching diff --git a/tensorflow/contrib/autograph/BUILD b/tensorflow/contrib/autograph/BUILD index ad700ac4a0342e2a7bc07a6ecf6710cea892e296..e37ad7a7581666e8207d5d35e197be3b3576a24d 100644 --- a/tensorflow/contrib/autograph/BUILD +++ b/tensorflow/contrib/autograph/BUILD @@ -21,11 +21,9 @@ py_library( ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], + # This module is kept for backward compatibility only. To depend on AutoGraph, + # use //third_party/tensorflow/python/autograph instead. deps = [ - "//tensorflow/contrib/autograph/impl", - "//tensorflow/contrib/autograph/lang", - "//tensorflow/contrib/autograph/pyct", - "//tensorflow/contrib/autograph/utils", - "//tensorflow/python:util", + "//tensorflow/python/autograph", ], ) diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index cc54da4daa9a5bb4e64145963ffec63021d08876..6ea2db72c411f2f19a06ff108d6b63fc3bde352b 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -1,5 +1,12 @@ # AutoGraph +**NOTE: As tensorflow.contrib is being +[deprecated](https://github.com/tensorflow/community/pull/18), AutoGraph is +moving into TensorFlow core. + +The new code location is `tensorflow/python/autograph`. +** + IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! 
We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)). AutoGraph is a Python to TensorFlow compiler. diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index 26e7a4a4d38e264486c981e6fc4c547bcc53b302..137bc59202b26c1c224fec4c2fca2dec83db13a5 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -12,57 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Autograph compiles Python code into equivalent TensorFlow code. +"""This is the legacy module for AutoGraph, kept for backward compatibility. -Equivalent here means that they have the same effect when executed. +New users should instead use `tensorflow.python.autograph`. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# TODO(mdan): Bring only the relevant symbols to the top level. 
-from tensorflow.contrib.autograph import operators -from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.core.errors import GraphConstructionError -from tensorflow.contrib.autograph.core.errors import TfRuntimeError -from tensorflow.contrib.autograph.core.errors import improved_errors -from tensorflow.contrib.autograph.impl.api import RunMode -from tensorflow.contrib.autograph.impl.api import convert -from tensorflow.contrib.autograph.impl.api import converted_call -from tensorflow.contrib.autograph.impl.api import do_not_convert -from tensorflow.contrib.autograph.impl.api import to_code -from tensorflow.contrib.autograph.impl.api import to_graph -from tensorflow.contrib.autograph.lang.directives import set_element_type -from tensorflow.contrib.autograph.lang.directives import set_loop_options -from tensorflow.contrib.autograph.lang.special_functions import stack -from tensorflow.contrib.autograph.lang.special_functions import tensor_list -from tensorflow.contrib.autograph.pyct.transformer import AutographParseError -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - # Main API - 'RunMode', - 'convert', - 'converted_call', - 'do_not_convert', - 'to_code', - 'to_graph', - # Overloaded operators - 'operators', - # Errors - 'improved_errors', - 'GraphConstructionError', - 'TfRuntimeError', - # Python language "extensions" - 'set_element_type', - 'set_loop_options', - 'stack', - 'tensor_list', - # Exceptions - 'AutographParseError', - # Utilities: to be removed - 'utils', -] - -remove_undocumented(__name__, _allowed_symbols) +from tensorflow.python.autograph import * # pylint:disable=wildcard-import diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py deleted file mode 100644 index 57b5f747417613a5dd5bce08e4a9e9ef98442cf6..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/utils/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ 
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility module that contains APIs usable in the generated code.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin -from tensorflow.contrib.autograph.utils.builtins import dynamic_print -from tensorflow.contrib.autograph.utils.builtins import dynamic_range -from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns -from tensorflow.contrib.autograph.utils.misc import alias_tensors -from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is -from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is_not -from tensorflow.contrib.autograph.utils.multiple_dispatch import run_cond -from tensorflow.contrib.autograph.utils.py_func import wrap_py_func -from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append -from tensorflow.contrib.autograph.utils.testing import fake_tf -from tensorflow.contrib.autograph.utils.type_check import is_tensor diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py deleted file mode 100644 index 
4dd440ef197b7e24b901bc9e30794b0182378a32..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Builtin conversion utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - -import six - -from tensorflow.contrib.autograph.utils import py_func -from tensorflow.contrib.autograph.utils import type_check -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import list_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import math_ops - - -def dynamic_builtin(f, *args, **kwargs): - """Converts a builtin function call inline.""" - if f is len: - return dynamic_len(*args, **kwargs) - if six.PY2 and f is xrange: - return dynamic_range(*args, **kwargs) - if f is range: - return dynamic_range(*args, **kwargs) - if f is int: - return dynamic_int(*args, **kwargs) - if f is float: - return dynamic_float(*args, **kwargs) - if f is abs: - return dynamic_abs(*args, **kwargs) - - raise NotImplementedError( - 'The "%s" builtin is not yet supported.' 
% f.__name__) - - -def dynamic_len(list_or_tensor): - """Implementation of len using dynamic dispatch.""" - if _is_tensor_list(list_or_tensor): - return list_ops.tensor_list_length(list_or_tensor) - elif tensor_util.is_tensor(list_or_tensor): - shape = list_or_tensor.shape - if not shape.ndims: - raise ValueError( - 'len requires non-zero rank for tensor "%s"' % list_or_tensor) - return array_ops.shape(list_or_tensor)[0] - return len(list_or_tensor) - - -def _is_tensor_list(list_or_tensor): - return (tensor_util.is_tensor(list_or_tensor) - and list_or_tensor.dtype == dtypes.variant) - - -def dynamic_int(num_or_tensor, **kwargs): - """Implementation of int() using dynamic dispatch.""" - if tensor_util.is_tensor(num_or_tensor): - return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs) - return int(num_or_tensor) - - -def dynamic_float(num_or_tensor, **kwargs): - """Implementation of float() using dynamic dispatch.""" - if tensor_util.is_tensor(num_or_tensor): - return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs) - return float(num_or_tensor) - - -def dynamic_abs(num_or_tensor, **kwargs): - if tensor_util.is_tensor(num_or_tensor): - return math_ops.abs(num_or_tensor, **kwargs) - else: - return abs(num_or_tensor, **kwargs) - - -def dynamic_range(start_or_stop, stop=None, step=None): - """Implementation of range using dynamic dispatch.""" - if type_check.is_tensor(start_or_stop, stop, step): - if step is not None: - return math_ops.range(start_or_stop, stop, step) - if stop is not None: - return math_ops.range(start_or_stop, stop) - return math_ops.range(start_or_stop) - - if step is not None: - return range(start_or_stop, stop, step) - elif stop is not None: - return range(start_or_stop, stop) - return range(start_or_stop) - - -def is_tf_print_compatible(value): - # TODO(mdan): Enable once we can reliably test this. - # This is currently disabled because we can't capture the output of - # op kernels from Python. 
- del value - return False - - -def dynamic_print(*values): - """Implementation of print using dynamic dispatch. - - The function attempts to use tf.Print if all the values are compatible. - Otherwise, it will fall back to py_func. - - Args: - *values: values to print - Returns: - A dummy value indicating the print completed. If tf. - """ - - if all(map(is_tf_print_compatible, values)): - return logging_ops.Print(1, values) - - def print_wrapper(*vals): - if six.PY3: - # TensorFlow doesn't seem to generate Unicode when passing strings to - # py_func. This causes the print to add a "b'" wrapper to the output, - # which is probably never what you want. - vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals) - print(*vals) - # The flush helps avoid garbled output in IPython. - sys.stdout.flush() - - return py_func.wrap_py_func( - print_wrapper, None, values, use_dummy_return=True) diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py deleted file mode 100644 index b1cd5253bc3ffb1e67d89ef79cf56eaeb65fae07..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for builtins module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - -import six - -from tensorflow.contrib.autograph.utils import builtins -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.platform import test - - -class BuiltinsTest(test.TestCase): - - def test_dynamic_len_tf_scalar(self): - a = constant_op.constant(1) - - with self.assertRaisesRegexp(ValueError, - 'len requires non-zero rank for tensor.*'): - with self.test_session() as sess: - sess.run(builtins.dynamic_builtin(len, a)) - - def test_dynamic_len_tf_array(self): - a = constant_op.constant([1, 2, 3]) - - with self.test_session() as sess: - self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a))) - - def test_dynamic_abs_tf_scalar(self): - a = constant_op.constant(-1) - - with self.test_session() as sess: - self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a))) - - def test_dynamic_abs_tf_array(self): - a = constant_op.constant([-1, 2, -3]) - - with self.test_session() as sess: - self.assertListEqual([1, 2, 3], - list(sess.run(builtins.dynamic_builtin(abs, a)))) - - def test_dynamic_abs_py_scalar(self): - a = -1 - self.assertEqual(1, builtins.dynamic_builtin(abs, a)) - - def test_dynamic_len_tf_matrix(self): - a = constant_op.constant([[1, 2], [3, 4]]) - - with self.test_session() as sess: - self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a))) - - def test_dynamic_len_py_list(self): - a = [3] * 5 - - self.assertEqual(5, builtins.dynamic_builtin(len, a)) - - def test_dynamic_range_all_python(self): - self.assertListEqual(list(builtins.dynamic_builtin(range, 3)), [0, 1, 2]) - self.assertListEqual(list(builtins.dynamic_builtin(range, 1, 3)), [1, 2]) - self.assertListEqual( - list(builtins.dynamic_builtin(range, 2, 0, -1)), [2, 
1]) - - def test_dynamic_range_tf(self): - with self.test_session() as sess: - self.assertAllEqual( - sess.run(builtins.dynamic_builtin(range, constant_op.constant(3))), - [0, 1, 2]) - self.assertAllEqual( - sess.run(builtins.dynamic_builtin(range, 1, constant_op.constant(3))), - [1, 2]) - self.assertAllEqual( - sess.run( - builtins.dynamic_builtin(range, 2, 0, constant_op.constant(-1))), - [2, 1]) - - def test_dynamic_range_detection(self): - def range(x): # pylint:disable=redefined-builtin - return x - - # Functions that just have the names of builtins are rejected. - with self.assertRaises(NotImplementedError): - self.assertEqual(builtins.dynamic_builtin(range, 1), 1) - if six.PY2: - self.assertListEqual( - list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2]) - self.assertListEqual( - list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2]) - self.assertListEqual( - list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2]) - - def test_casts(self): - i = constant_op.constant(2, dtype=dtypes.int32) - f = constant_op.constant(1.0, dtype=dtypes.float32) - - self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32) - self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32) - self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32) - self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32) - - self.assertEqual(builtins.dynamic_builtin(int, True), 1) - self.assertEqual(builtins.dynamic_builtin(int, False), 0) - self.assertEqual(builtins.dynamic_builtin(float, True), 1.0) - self.assertEqual(builtins.dynamic_builtin(float, False), 0.0) - - def test_dynamic_print_tf(self): - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - with self.test_session() as sess: - sess.run(builtins.dynamic_print('test message', 1)) - self.assertEqual(out_capturer.getvalue(), 'test message 1\n') - finally: - sys.stdout = sys.__stdout__ - - def test_dynamic_print_complex(self): - try: - out_capturer = 
six.StringIO() - sys.stdout = out_capturer - with self.test_session() as sess: - sess.run(builtins.dynamic_print('test message', [1, 2])) - self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') - finally: - sys.stdout = sys.__stdout__ - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc index a25a641cdb4608dee6d6c1bd18697860cc1f5613..6138d7912601344ef7422fd50fb35c8401fd2e63 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -172,6 +172,11 @@ class BigtableTableOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU), BigtableTableOp); +} // namespace + +namespace data { +namespace { + class ToBigtableOp : public AsyncOpKernel { public: explicit ToBigtableOp(OpKernelConstruction* ctx) @@ -354,5 +359,6 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU), ToBigtableOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h index a2a5df1037a00ccfdff1910dd950d7b012e684e2..4652021fecabfa11fa6a8754dc884d89e151b590 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -79,6 +79,8 @@ class BigtableTableResource : public ResourceBase { ::google::cloud::bigtable::noex::Table table_; }; +namespace data { + // BigtableReaderDatasetIterator is an abstract class for iterators from // datasets that are "readers" (source datasets, not transformation datasets) // that read from Bigtable. 
@@ -138,6 +140,8 @@ class BigtableReaderDatasetIterator : public DatasetIterator { ::google::cloud::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_); }; +} // namespace data + } // namespace tensorflow #endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc index bd32672aa99d7bf70c44a264f488482c4f213a0b..11f530e82a186f410bc505de7fbf1b478240c340 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { namespace { class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { @@ -226,4 +227,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU), BigtableLookupDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc index a803fdcb49604ef4e596b64d62c7278c69764c15..5cab729d9c16f144ec5671ad775f384ad79ad9e0 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { namespace { class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { @@ -111,4 +112,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU), BigtablePrefixKeyDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc index 5cd0371c79f7eded9303b81dd388df8d306dff80..4dc4647bd24f3a957bc93a9ed8c81b3c7deb6a47 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { namespace { class BigtableRangeKeyDatasetOp : public DatasetOpKernel { @@ -117,4 +118,5 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU), BigtableRangeKeyDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc index 6928d9423c84f7504fea3ac1abd929357da034a5..736775bdac10da757190c0b2e4a7672d55edf317 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { namespace { class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { @@ -205,4 +206,5 @@ REGISTER_KERNEL_BUILDER( BigtableSampleKeyPairsDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc index a759fb5063900199325304ccf83c52f3bdd7d702..208b7b3e08692c00c1fd879c2a02641fb05ff639 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { namespace { class BigtableSampleKeysDatasetOp : public DatasetOpKernel { @@ -118,4 +119,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU), BigtableSampleKeysDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc index 78a920b077680980a209ad8c30c09409a6f4ebf5..9407855fe88db9faec1949db98a725e5a1cd9f38 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { namespace { class BigtableScanDatasetOp : public DatasetOpKernel { @@ -224,4 +225,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU), BigtableScanDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 870ce2442bb5e98db7615c43054c9c827b8c88f0..4c7a538b385ec19f520bff79bab20a121221c60f 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -52,7 +52,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): center_bias=True, use_core_libs=False, output_leaf_index=False, - override_global_step_value=None): + override_global_step_value=None, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -94,6 +95,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): trees were trained), this parameter can be used to set the global step to a large value, making it look like that number of training steps ran. If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. Raises: ValueError: If learner_config is not valid. 
@@ -134,7 +136,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'use_core_libs': use_core_libs, 'output_leaf_index': output_leaf_index, - 'override_global_step_value': override_global_step_value + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, }, model_dir=model_dir, config=config, @@ -159,7 +162,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): center_bias=True, use_core_libs=False, output_leaf_index=False, - override_global_step_value=None): + override_global_step_value=None, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -201,6 +205,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): trees were trained), this parameter can be used to set the global step to a large value, making it look like that number of training steps ran. If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. """ head = head_lib.regression_head( label_name=label_name, @@ -224,7 +229,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, - 'override_global_step_value': override_global_step_value + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, }, model_dir=model_dir, config=config, @@ -251,7 +257,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): center_bias=True, use_core_libs=False, output_leaf_index=False, - override_global_step_value=None): + override_global_step_value=None, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. 
Args: @@ -289,6 +296,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): trees were trained), this parameter can be used to set the global step to a large value, making it look like that number of training steps ran. If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. """ super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -303,7 +311,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, - 'override_global_step_value': override_global_step_value + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, }, model_dir=model_dir, config=config, @@ -329,7 +338,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): center_bias=False, use_core_libs=False, output_leaf_index=False, - override_global_step_value=None): + override_global_step_value=None, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeRanker instance. This is an estimator that can be trained off the pairwise data and can be @@ -377,6 +387,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): trees were trained), this parameter can be used to set the global step to a large value, making it look like that number of training steps ran. If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. + Raises: ValueError: If learner_config is not valid. 
""" @@ -395,7 +407,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): 'use_core_libs': use_core_libs, 'output_leaf_index': output_leaf_index, 'ranking_model_pair_keys': ranking_model_pair_keys, - 'override_global_step_value': override_global_step_value + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, }, model_dir=model_dir, config=config, @@ -444,7 +457,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - output_leaf_index=False): + output_leaf_index=False, + num_quantiles=100): """Initializes a core version of GradientBoostedDecisionTreeEstimator. Args: @@ -474,6 +488,7 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): for example_prediction_result in result_dict: # access leaf index list by example_prediction_result["leaf_index"] # which contains one leaf index per tree + num_quantiles: Number of quantiles to build for numeric feature values. """ def _model_fn(features, labels, mode, config): @@ -493,7 +508,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'use_core_libs': True, 'output_leaf_index': output_leaf_index, - 'override_global_step_value': None + 'override_global_step_value': None, + 'num_quantiles': num_quantiles, }, output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) @@ -517,7 +533,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): label_keys=None, logits_modifier_function=None, center_bias=False, - output_leaf_index=False): + output_leaf_index=False, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeRanker instance. 
This is an estimator that can be trained off the pairwise data and can be @@ -552,6 +569,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): for result_dict in result_iter: # access leaf index list by result_dict["leaf_index"] # which contains one leaf index per tree + num_quantiles: Number of quantiles to build for numeric feature values. Raises: ValueError: If learner_config is not valid. @@ -576,7 +594,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): 'use_core_libs': True, 'output_leaf_index': output_leaf_index, 'ranking_model_pair_keys': ranking_model_pair_keys, - 'override_global_step_value': None + 'override_global_step_value': None, + 'num_quantiles': num_quantiles, }, output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 04b46c3483fa25286078b88c2776b76e4f3c0bcf..a6e422847d3914188bca9e6dff797ba1ffb06749 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -81,6 +81,7 @@ def model_builder(features, logits_modifier_function = params["logits_modifier_function"] output_leaf_index = params["output_leaf_index"] override_global_step_value = params.get("override_global_step_value", None) + num_quantiles = params["num_quantiles"] if features is None: raise ValueError("At least one feature must be specified.") @@ -116,7 +117,8 @@ def model_builder(features, logits_dimension=head.logits_dimension, features=training_features, use_core_columns=use_core_libs, - output_leaf_index=output_leaf_index) + output_leaf_index=output_leaf_index, + num_quantiles=num_quantiles) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] @@ -237,6 +239,7 @@ def ranking_model_builder(features, output_leaf_index = params["output_leaf_index"] 
ranking_model_pair_keys = params["ranking_model_pair_keys"] override_global_step_value = params.get("override_global_step_value", None) + num_quantiles = params["num_quantiles"] if features is None: raise ValueError("At least one feature must be specified.") @@ -299,7 +302,8 @@ def ranking_model_builder(features, logits_dimension=head.logits_dimension, features=main_features, use_core_columns=use_core_libs, - output_leaf_index=output_leaf_index) + output_leaf_index=output_leaf_index, + num_quantiles=num_quantiles) with ops.name_scope("gbdt", "gbdt_optimizer"): # Logits for inference. diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 1375fddf2bea1a8f856c35d756c38a8beb14a53f..606da663dc2e43688bc42bf6e33a48cd680f54e1 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -296,8 +296,9 @@ class QuantileAccumulatorAddSummariesOp : public OpKernel { int64 start, int64 end) { for (int resource_handle_idx = start; resource_handle_idx < end; ++resource_handle_idx) { - ResourceHandle handle = resource_handle_list[resource_handle_idx] - .flat()(0); + const ResourceHandle& handle = + resource_handle_list[resource_handle_idx] + .flat()(0); QuantileStreamResource* streams_resource; // Create a reference to the underlying resource using the handle. 
OP_REQUIRES_OK(context, @@ -709,8 +710,9 @@ class QuantileAccumulatorGetBucketsOp : public OpKernel { &buckets_list, stamp_token](int64 start, int64 end) { for (int resource_handle_idx = start; resource_handle_idx < end; ++resource_handle_idx) { - ResourceHandle handle = resource_handle_list[resource_handle_idx] - .flat()(0); + const ResourceHandle& handle = + resource_handle_list[resource_handle_idx] + .flat()(0); QuantileStreamResource* streams_resource; OP_REQUIRES_OK(context, LookupResource(context, handle, &streams_resource)); diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 3b28ed77f325b3f8b09fe6b9d2776eff82ff53a7..51e0c2e431acbea727bc0b2149557d0e30c8c432 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -862,6 +862,15 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); equality_split->set_feature_column(state->feature_column_group_id()); + CHECK(feature_ids(best_feature_idx, 0) != bias_feature_id) + << "Unexpected feature ID selected. 
" + << "Start feature ID: [" << start_index << "] " + << feature_ids(start_index, 0) << ", " << feature_ids(start_index, 1) + << "\nBest feature ID: [" << best_feature_idx << "] " + << feature_ids(best_feature_idx, 0) << ", " + << feature_ids(best_feature_idx, 1) + << "\nPartition IDS: " << partition_ids(start_index) << " " + << partition_ids(best_feature_idx); equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc index 90a0655201f8cb8df6fc6417cb51216dec91b4d7..e446c411a8d5075563b8f8b912b29df310e16c8c 100644 --- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc @@ -448,8 +448,9 @@ class StatsAccumulatorScalarAddOp : public OpKernel { stamp_token](int64 start, int64 end) { for (int resource_handle_idx = start; resource_handle_idx < end; ++resource_handle_idx) { - ResourceHandle handle = resource_handle_list[resource_handle_idx] - .flat()(0); + const ResourceHandle& handle = + resource_handle_list[resource_handle_idx] + .flat()(0); StatsAccumulatorScalarResource* accumulator_resource; OP_REQUIRES_OK(context, LookupResource(context, handle, @@ -512,8 +513,9 @@ class StatsAccumulatorTensorAddOp : public OpKernel { stamp_token](int64 start, int64 end) { for (int resource_handle_idx = start; resource_handle_idx < end; ++resource_handle_idx) { - ResourceHandle handle = resource_handle_list[resource_handle_idx] - .flat()(0); + const ResourceHandle& handle = + resource_handle_list[resource_handle_idx] + .flat()(0); StatsAccumulatorTensorResource* accumulator_resource; OP_REQUIRES_OK(context, LookupResource(context, handle, diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py 
b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index e6407174b1a6557cc101a3485b1a25d12d54a0ae..4da25298cb82093ac501997cc21c48265df06860 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -_BIAS_FEATURE_ID = -1 +_BIAS_FEATURE_ID = int(dtypes.int64.min) class EqualitySplitHandler(base_split_handler.BaseSplitHandler): @@ -141,11 +141,18 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): # The bias is computed on gradients and hessians (and not # filtered_gradients) which have exactly one value per example, so we # don't double count a gradient in multivalent columns. + # Since unsorted_segment_sum can be numerically unstable, use 64bit + # operation. + gradients64 = math_ops.cast(gradients, dtypes.float64) + hessians64 = math_ops.cast(hessians, dtypes.float64) per_partition_gradients = math_ops.unsorted_segment_sum( - gradients, mapped_partitions, array_ops.size(unique_partitions)) + gradients64, mapped_partitions, array_ops.size(unique_partitions)) per_partition_hessians = math_ops.unsorted_segment_sum( - hessians, mapped_partitions, array_ops.size(unique_partitions)) - + hessians64, mapped_partitions, array_ops.size(unique_partitions)) + per_partition_gradients = math_ops.cast(per_partition_gradients, + dtypes.float32) + per_partition_hessians = math_ops.cast(per_partition_hessians, + dtypes.float32) # Prepend a bias feature per partition that accumulates the stats for all # examples in that partition. 
# Bias is added to the stats even if there are no examples with values in diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index d9f03c3840f8edd88174be4e97aaaf7d0efd220b..94ea7bc2eb7b098a0628683167510bf4e3c2426e 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -47,7 +47,7 @@ def get_empty_tensors(gradient_shape, hessian_shape): class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidates(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Feature ID | # i0 | (0.2, 0.12) | 0 | 1,2 | @@ -281,7 +281,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): gains[0], 0.00001) def testGenerateFeatureSplitCandidatesSumReduction(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Feature ID | # i0 | (0.2, 0.12) | 0 | 1,2 | @@ -404,7 +404,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) def testGenerateFeatureSplitCandidatesMulticlass(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Batch size is 4, 2 gradients per each instance. 
gradients = array_ops.constant( [[0.2, 0.1], [-0.5, 0.2], [1.2, 3.4], [4.0, -3.5]], shape=[4, 2]) @@ -482,7 +482,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(1, split_node.feature_id) def testEmpty(self): - with self.test_session() as sess: + with self.cached_session() as sess: gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = [0, 0, 0, 1] @@ -530,7 +530,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testInactive(self): - with self.test_session() as sess: + with self.cached_session() as sess: gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = [0, 0, 0, 1] diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 5532bd026ab695d166bc2e2872ecc551920978d5..74b0ea6989c65e83e7a466107d624712a0e72d1b 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -50,7 +50,7 @@ def get_empty_tensors(gradient_shape, hessian_shape): class DenseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidates(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -183,7 +183,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) def testObliviousFeatureSplitGeneration(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 1 | 3 | 
@@ -320,7 +320,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(2, oblivious_split_info.children_parent_id[1]) def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -458,7 +458,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): - with self.test_session() as sess: + with self.cached_session() as sess: dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52]) # Batch size is 4, 2 gradients per each instance. gradients = array_ops.constant( @@ -546,7 +546,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.3, split_node.threshold, 1e-6) def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): - with self.test_session() as sess: + with self.cached_session() as sess: dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52]) # Batch size is 4, 2 gradients per each instance. 
gradients = array_ops.constant( @@ -633,7 +633,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.3, split_node.threshold, 1e-6) def testGenerateFeatureSplitCandidatesInactive(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -708,7 +708,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testGenerateFeatureSplitCandidatesWithTreeComplexity(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -842,7 +842,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.threshold, 0.00001) def testGenerateFeatureSplitCandidatesWithMinNodeWeight(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -951,7 +951,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testGenerateFeatureSplitCandidates(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Sparse Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -1074,7 +1074,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Sparse Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -1207,7 +1207,7 @@ class 
SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Batch is 4, 2 classes gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) @@ -1302,7 +1302,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self): - with self.test_session() as sess: + with self.cached_session() as sess: # Batch is 4, 2 classes gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]]) @@ -1397,7 +1397,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.52, split_node.split.threshold) def testGenerateFeatureSplitCandidatesInactive(self): - with self.test_session() as sess: + with self.cached_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Sparse Quantile | # i0 | (0.2, 0.12) | 0 | 1 | @@ -1475,7 +1475,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testEmpty(self): - with self.test_session() as sess: + with self.cached_session() as sess: indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2]) # No values in this feature column in this mini-batch. values = array_ops.constant([], dtype=dtypes.float32) @@ -1545,7 +1545,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): def testEmptyBuckets(self): """Test that reproduces the case when quantile buckets were empty.""" - with self.test_session() as sess: + with self.cached_session() as sess: sparse_column = array_ops.sparse_placeholder(dtypes.float32) # We have two batches - at first, a sparse feature is empty. 
@@ -1638,7 +1638,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(splits), 0) def testDegenerativeCase(self): - with self.test_session() as sess: + with self.cached_session() as sess: # One data example only, one leaf and thus one quantile bucket.The same # situation is when all examples have the same values. This case was # causing before a failure. diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index 4278a30ba9d35bc4e57364b63777c01a4508223d..46dfbdefeb00ffa075f7e7b6835b73eb258443d2 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -331,7 +331,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([[], []], dropout_info.eval()) def testObliviousEnsemble(self): - with self.test_session(): + with self.cached_session(): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Bias tree. tree1 = tree_ensemble_config.trees.add() @@ -1399,7 +1399,7 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([0, 0], result.eval()) def testObliviousTreeNonFinalized(self): - with self.test_session(): + with self.cached_session(): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() # Depth 3 tree. 
tree1 = tree_ensemble_config.trees.add() diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index b3e4c2e5f7a907892d66ad4181eb6ed8589bab6e..86fd5770a033a15df5788d3f74563c82f660371c 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -411,7 +411,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowEmptyEnsembleObliviousCase(self): """Test growing an empty ensemble in the oblivious case.""" - with self.test_session() as session: + with self.cached_session() as session: # Create empty ensemble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() tree_ensemble_handle = model_ops.tree_ensemble_variable( @@ -1620,7 +1620,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowEnsembleTreeLayerByLayerObliviousCase(self): """Test growing an existing ensemble with the last tree not finalized.""" - with self.test_session() as session: + with self.cached_session() as session: # Create existing ensemble with one root split tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge( @@ -1810,7 +1810,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowEnsembleWithEmptyNodesMiddleCase(self): """Test case: The middle existing leaves don't have examples.""" - with self.test_session() as session: + with self.cached_session() as session: tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge( """ @@ -2071,7 +2071,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): def testGrowEnsembleWithEmptyNodesBorderCase(self): """Test case: The first and last existing leaves don't have examples.""" - with self.test_session() as session: + with self.cached_session() as session: tree_ensemble_config = 
tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge( """ diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index b008c6e5346980d926c851919bfc28ecced266b5..c7eb2493a8ba56943740326cf68ad6b3a91f67c4 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -304,7 +304,8 @@ class GradientBoostedDecisionTreeModel(object): feature_columns=None, use_core_columns=False, output_leaf_index=False, - output_leaf_index_modes=None): + output_leaf_index_modes=None, + num_quantiles=100): """Construct a new GradientBoostedDecisionTreeModel function. Args: @@ -327,6 +328,7 @@ class GradientBoostedDecisionTreeModel(object): output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which dictates when leaf indices will be outputted. By default, leaf indices are only outputted in INFER mode. + num_quantiles: Number of quantiles to build for numeric feature values. Raises: ValueError: if inputs are not valid. 
@@ -399,6 +401,7 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config = learner_config self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() + self._num_quantiles = num_quantiles self._max_tree_depth = variables.Variable( initial_value=self._learner_config.constraints.max_tree_depth) self._attempted_trees = variables.Variable( @@ -689,8 +692,8 @@ class GradientBoostedDecisionTreeModel(object): loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction) weak_learner_type = constant_op.constant( self._learner_config.weak_learner_type) - epsilon = 0.01 - num_quantiles = 100 + num_quantiles = self._num_quantiles + epsilon = 1.0 / num_quantiles strategy_tensor = constant_op.constant(strategy) with ops.device(self._get_replica_device_setter(worker_device)): # Create handlers for dense float columns diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 150d734db6cdd8023ab6d91a49872f657bcdbdea..94b7f4f867655bf7fdf94e8488eeae7088c41622 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -37,6 +37,7 @@ Checkpoint management: Saving and restoring Python state: @@NumpyState +@@PythonStateWrapper """ from __future__ import absolute_import @@ -45,6 +46,7 @@ from __future__ import print_function from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker from tensorflow.contrib.checkpoint.python.python_state import NumpyState +from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph diff --git a/tensorflow/contrib/checkpoint/python/python_state.py 
b/tensorflow/contrib/checkpoint/python/python_state.py index 9b11035b6d277851ea0a0071062bf5cf6b6b2185..302d5cfb79a08b6adf52ebd44533152c5454eadc 100644 --- a/tensorflow/contrib/checkpoint/python/python_state.py +++ b/tensorflow/contrib/checkpoint/python/python_state.py @@ -17,7 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import functools +import six import numpy @@ -101,7 +103,7 @@ class NumpyState(base.CheckpointableBase): # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making # ndarrays checkpointable natively and using standard checkpointable list # tracking. - if isinstance(value, numpy.ndarray): + if isinstance(value, (numpy.ndarray, numpy.generic)): try: existing = super(NumpyState, self).__getattribute__(name) existing.array = value @@ -127,7 +129,29 @@ class NumpyState(base.CheckpointableBase): super(NumpyState, self).__setattr__(name, value) -class _NumpyWrapper(base.CheckpointableBase): +@six.add_metaclass(abc.ABCMeta) +class PythonStateWrapper(base.CheckpointableBase): + """Wraps a Python object for storage in an object-based checkpoint.""" + + @abc.abstractmethod + def _serialize(self): + """Callback for `PythonStringStateSaveable` to serialize the object.""" + + @abc.abstractmethod + def _deserialize(self, string_value): + """Callback for `PythonStringStateSaveable` to deserialize the object.""" + + def _gather_saveables_for_checkpoint(self): + """Specify callbacks for saving and restoring `array`.""" + return { + "py_state": functools.partial( + base.PythonStringStateSaveable, + state_callback=self._serialize, + restore_callback=self._deserialize) + } + + +class _NumpyWrapper(PythonStateWrapper): """Wraps a NumPy array for storage in an object-based checkpoint.""" def __init__(self, array): @@ -139,7 +163,7 @@ class _NumpyWrapper(base.CheckpointableBase): self.array = array def _serialize(self): - """Callback for `PythonStringStateSaveable` to 
serialize the array.""" + """Callback to serialize the array.""" string_file = BytesIO() try: numpy.save(string_file, self.array, allow_pickle=False) @@ -149,18 +173,10 @@ class _NumpyWrapper(base.CheckpointableBase): return serialized def _deserialize(self, string_value): - """Callback for `PythonStringStateSaveable` to deserialize the array.""" + """Callback to deserialize the array.""" string_file = BytesIO(string_value) try: self.array = numpy.load(string_file, allow_pickle=False) finally: string_file.close() - def _gather_saveables_for_checkpoint(self): - """Specify callbacks for saving and restoring `array`.""" - return { - "array": functools.partial( - base.PythonStringStateSaveable, - state_callback=self._serialize, - restore_callback=self._deserialize) - } diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py index 0439a4755e36fc3be6e065d18d3e835feda8aab3..45494351ff4e6c8c75634d8563c3fb63c6089036 100644 --- a/tensorflow/contrib/checkpoint/python/python_state_test.py +++ b/tensorflow/contrib/checkpoint/python/python_state_test.py @@ -40,10 +40,13 @@ class NumpyStateTests(test.TestCase): save_state.a = numpy.ones([2, 2]) save_state.b = numpy.ones([2, 2]) save_state.b = numpy.zeros([2, 2]) + save_state.c = numpy.int64(3) self.assertAllEqual(numpy.ones([2, 2]), save_state.a) self.assertAllEqual(numpy.zeros([2, 2]), save_state.b) + self.assertEqual(3, save_state.c) first_save_path = saver.save(prefix) save_state.a[1, 1] = 2. + save_state.c = numpy.int64(4) second_save_path = saver.save(prefix) load_state = python_state.NumpyState() @@ -51,6 +54,7 @@ class NumpyStateTests(test.TestCase): loader.restore(first_save_path).initialize_or_restore() self.assertAllEqual(numpy.ones([2, 2]), load_state.a) self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + self.assertEqual(3, load_state.c) load_state.a[0, 0] = 42. 
self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a) loader.restore(first_save_path).run_restore_ops() @@ -58,6 +62,7 @@ class NumpyStateTests(test.TestCase): loader.restore(second_save_path).run_restore_ops() self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a) self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) + self.assertEqual(4, load_state.c) def testNoGraphPollution(self): graph = ops.Graph() diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py index 493b3c6f1b5e7a7a7dc1dd4f48d2f54c1d284098..11e177cd0c81f99bd6e00eac4de90a46fb9f64f0 100644 --- a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py +++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py @@ -197,7 +197,7 @@ class BigQueryReaderOpsTest(test.TestCase): def _ReadAndCheckRowsUsingFeatures(self, num_rows): self.server.handler.num_rows = num_rows - with self.test_session() as sess: + with self.cached_session() as sess: feature_configs = { "int64_col": parsing_ops.FixedLenFeature( @@ -254,7 +254,7 @@ class BigQueryReaderOpsTest(test.TestCase): num_rows = 10 self.server.handler.num_rows = num_rows - with self.test_session() as sess: + with self.cached_session() as sess: reader = cloud.BigQueryReader( project_id=_PROJECT, dataset_id=_DATASET, diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py index 9b6c056d6c8adfa50b95aefb8e9740631327a572..4f2ecbcb170b56ab276ec37bbaa3db2485d58f49 100644 --- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py @@ -26,7 +26,7 @@ class GcsConfigOpsTest(test.TestCase): def testSetBlockCache(self): cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024) - with self.test_session() as sess: + with self.cached_session() as sess: gcs_config_ops.configure_gcs(sess, block_cache=cfg) def 
testConfigureGcsHook(self): @@ -36,7 +36,7 @@ class GcsConfigOpsTest(test.TestCase): 'type': 'authorized_user'} hook = gcs_config_ops.ConfigureGcsHook(credentials=creds) hook.begin() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None hook.after_create_session(sess, None) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 1ab150d74ac00c5f9acf3c9399880708b2f62b1e..1056894f18f1ec19a598dfbd1161d7f9bea7e94f 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -229,6 +229,10 @@ class TPUClusterResolver(ClusterResolver): def get_master(self): return self.master() + def get_job_name(self): + if self._shouldResolve(): + return self._job_name + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 0b79f718d4823a987e02804f59a432ee46d0ada3..789dab81ed848851f6597ec8dfae3d3455e84f86 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -1,6 +1,10 @@ TensorFlow CMake build ====================== +CMAKE build is deprecated for TensorFlow. Please use `bazel` to build TF for all +platforms. For details, see the +[TensorFlow install guide](https://www.tensorflow.org/install/). + This directory contains CMake files for building TensorFlow on Microsoft Windows. 
[CMake](https://cmake.org) is a cross-platform tool that can generate build scripts for multiple build systems, including Microsoft diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index ad2af01bc002555ce48f8b9bfb7d8d724a1a7dc8..1a147e9c8e5a9fee17a81e37c9babe3c9ec0290b 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== include (ExternalProject) +include (GNUInstallDirs) set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive) set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz) @@ -35,7 +36,7 @@ if(WIN32) endif() endif() else() - set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng16.a) + set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/${CMAKE_INSTALL_LIBDIR}/libpng16.a) endif() set(png_HEADERS diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py index 9b4bf6271009161c4c449cd9c3cdab9fba90aa59..3e25079e02eb22cb8796cce1a49a3041bed58415 100644 --- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py @@ -75,7 +75,7 @@ class ExternalRegretOptimizerTest(test.TestCase): multipliers3 = standard_ops.constant([0.4, 0.7, -0.2, 0.5, 0.1]) expected_projected_multipliers3 = np.array([0.2, 0.5, 0.0, 0.3, 0.0]) - with self.test_session() as session: + with self.cached_session() as session: projected_multipliers1 = session.run( external_regret_optimizer._project_multipliers_wrt_euclidean_norm( multipliers1, 1.0)) @@ -122,7 +122,7 @@ class ExternalRegretOptimizerTest(test.TestCase): ] multipliers = [] - with 
self.test_session() as session: + with self.cached_session() as session: session.run(standard_ops.global_variables_initializer()) while len(multipliers) < len(expected_multipliers): multipliers.append(session.run(optimizer.lagrange_multipliers)) diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py index 34c4543dca97e12c8335e4c90b849820edaefa81..df0eced631718995fc3219657db6813da7375cba 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py @@ -97,7 +97,7 @@ class SwapRegretOptimizerTest(test.TestCase): matrix1 = np.matrix([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]]) matrix2 = np.matrix([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]]) - with self.test_session() as session: + with self.cached_session() as session: eigenvector1 = session.run( swap_regret_optimizer._maximal_eigenvector_power_method( standard_ops.constant(matrix1))) @@ -119,7 +119,7 @@ class SwapRegretOptimizerTest(test.TestCase): expected_projected_matrix = np.array([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]]) - with self.test_session() as session: + with self.cached_session() as session: projected_matrix = session.run( swap_regret_optimizer._project_stochastic_matrix_wrt_euclidean_norm( matrix)) @@ -134,7 +134,7 @@ class SwapRegretOptimizerTest(test.TestCase): expected_projected_matrix = np.array([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]]) - with self.test_session() as session: + with self.cached_session() as session: projected_matrix = session.run( standard_ops.exp( swap_regret_optimizer. 
@@ -165,7 +165,7 @@ class SwapRegretOptimizerTest(test.TestCase): ] matrices = [] - with self.test_session() as session: + with self.cached_session() as session: session.run(standard_ops.global_variables_initializer()) while len(matrices) < len(expected_matrices): matrices.append(session.run(optimizer.stochastic_matrix)) @@ -198,7 +198,7 @@ class SwapRegretOptimizerTest(test.TestCase): ] matrices = [] - with self.test_session() as session: + with self.cached_session() as session: session.run(standard_ops.global_variables_initializer()) while len(matrices) < len(expected_matrices): matrices.append(session.run(optimizer.stochastic_matrix)) diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 8cfe14205927bf7763cf36fa31012ab10fce995c..556d73184022dcc23add29114d717ab17302f8d4 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -61,7 +61,7 @@ class CrfTest(test.TestCase): for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, inputs_list, tag_indices_list): - with self.test_session() as sess: + with self.cached_session() as sess: sequence_score = crf.crf_sequence_score( inputs=array_ops.expand_dims(inputs, 0), tag_indices=array_ops.expand_dims(tag_indices, 0), @@ -96,7 +96,7 @@ class CrfTest(test.TestCase): ] for sequence_lengths, inputs, tag_bitmap in zip( sequence_lengths_list, inputs_list, tag_bitmap_list): - with self.test_session() as sess: + with self.cached_session() as sess: sequence_score = crf.crf_multitag_sequence_score( inputs=array_ops.expand_dims(inputs, 0), tag_bitmap=array_ops.expand_dims(tag_bitmap, 0), @@ -124,7 +124,7 @@ class CrfTest(test.TestCase): for dtype in (np.int32, np.int64): tag_indices = np.array([1, 2, 1, 0], dtype=dtype) sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: + with self.cached_session() as sess: unary_score = 
crf.crf_unary_score( tag_indices=array_ops.expand_dims(tag_indices, 0), sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), @@ -140,7 +140,7 @@ class CrfTest(test.TestCase): transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: + with self.cached_session() as sess: binary_score = crf.crf_binary_score( tag_indices=array_ops.expand_dims(tag_indices, 0), sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), @@ -176,7 +176,7 @@ class CrfTest(test.TestCase): tag_indices_list): num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_scores = [] # Compare the dynamic program with brute force computation. @@ -206,7 +206,7 @@ class CrfTest(test.TestCase): """ Test `crf_log_norm` when `sequence_lengths` contains one or more zeros. """ - with self.test_session() as sess: + with self.cached_session() as sess: inputs = constant_op.constant(np.ones([2, 10, 5], dtype=np.float32)) transition_params = constant_op.constant(np.ones([5, 5], @@ -226,7 +226,7 @@ class CrfTest(test.TestCase): sequence_lengths = np.array(3, dtype=np.int32) num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_log_likelihoods = [] # Make sure all probabilities sum to 1. 
@@ -254,7 +254,7 @@ class CrfTest(test.TestCase): num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_scores = [] all_sequences = [] @@ -310,7 +310,7 @@ class CrfTest(test.TestCase): num_words = inputs.shape[0] num_tags = inputs.shape[1] - with self.test_session() as sess: + with self.cached_session() as sess: all_sequence_scores = [] all_sequences = [] @@ -351,7 +351,7 @@ class CrfTest(test.TestCase): """ Test that crf_decode works when sequence_length contains one or more zeros. """ - with self.test_session() as sess: + with self.cached_session() as sess: inputs = constant_op.constant(np.ones([2, 10, 5], dtype=np.float32)) transition_params = constant_op.constant(np.ones([5, 5], diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 5e6c1520a2fc1c21678625c9d4aae04164b198f6..baec238c62e5cd375e3e8d46039e8e5b21269a6f 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -26,6 +26,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@CheckpointInputPipelineHook @@CsvDataset @@LMDBDataset +@@Optional @@RandomDataset @@Reducer @@SqlDataset @@ -38,7 +39,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@copy_to_device @@dense_to_sparse_batch @@enumerate_dataset - +@@get_next_as_optional @@get_single_element @@group_by_reducer @@group_by_window @@ -46,7 +47,6 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. 
@@make_batched_features_dataset @@make_csv_dataset @@make_saveable_from_iterator - @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave @@ -107,6 +107,8 @@ from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch from tensorflow.contrib.data.python.ops.unique import unique from tensorflow.contrib.data.python.ops.writers import TFRecordWriter +from tensorflow.python.data.ops.iterator_ops import get_next_as_optional +from tensorflow.python.data.ops.optional_ops import Optional # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc index e36c9c0634235022362b59a6699b4d550d6d0eee..c19a609780d5e2ac3175404bf8d3bcaf03b7dd01 100644 --- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" namespace tensorflow { +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -150,4 +151,5 @@ REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU), AssertNextDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index 0ba905b92e2d9a14128b540028687955bd96f2f0..21ec50fb6b8a5bbff6d7fe85fb949e127ceab8ed 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/core/lib/io/zlib_inputstream.h" namespace tensorflow { +namespace data { namespace { class CSVDatasetOp : public DatasetOpKernel { @@ -48,6 +49,9 @@ class CSVDatasetOp : public DatasetOpKernel { OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults_list)); for (int i = 0; i < record_defaults_list.size(); ++i) { + OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1, + errors::InvalidArgument( + "Each record default should be at most rank 1")); OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2, errors::InvalidArgument( "There should only be 1 default per field but field ", i, @@ -851,4 +855,5 @@ class CSVDatasetOp : public DatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc index ccf7ec1f842f5a1ad9b304c904f046ad49ed1757..a5321620bf666a8568fc073b0edad17440b30133 100644 --- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -276,5 +276,5 @@ REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU), DirectedInterleaveDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc index 4718c1c8b9d77b5dbac2a8caf11d9a0604af94c2..c3cb45dbf7bf4d3fed828a7d108472db24f7e33a 100644 --- a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc +++ b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { +namespace data { namespace { class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel { @@ -150,4 +151,5 @@ REGISTER_KERNEL_BUILDER(Name("IdentityIndexedDataset").Device(DEVICE_CPU), IdentityIndexedDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index db24e608463224f05159b57eb721718afd7cbb20..beec344534078dc3a1069dd2079684340b4f7919 100644 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -137,5 +137,5 @@ REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU), IgnoreErrorsDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.cc b/tensorflow/contrib/data/kernels/indexed_dataset.cc index c69564a31bbc3a07ff56e0da564e7e1b8323f464..ced8ab0d608cddc51ede38490bb8d9aecbe7da92 100644 --- a/tensorflow/contrib/data/kernels/indexed_dataset.cc +++ b/tensorflow/contrib/data/kernels/indexed_dataset.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" namespace tensorflow { - +namespace data { namespace { Status VerifyTypesMatch(const DataTypeVector& expected, @@ -367,6 +367,7 @@ REGISTER_KERNEL_BUILDER(Name("IndexedDatasetMaterialize").Device(DEVICE_CPU), MaterializeDatasetOp); REGISTER_KERNEL_BUILDER(Name("IndexedDatasetGet").Device(DEVICE_CPU), IndexedDatasetGet); -} // namespace +} // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.h b/tensorflow/contrib/data/kernels/indexed_dataset.h index 6149de888cc0a966ead48c790074d63ca028f1e8..7aa2d3fdbc2db768b75bbdcaad7d71b29a3ca4c9 100644 --- a/tensorflow/contrib/data/kernels/indexed_dataset.h +++ b/tensorflow/contrib/data/kernels/indexed_dataset.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { // TODO(saeta): Urgh, this is ugly. 
class MaterializedIndexedDataset { @@ -112,6 +113,7 @@ Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor, Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset, Tensor* tensor); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_ diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc index 80f39992fbb1ff1395c308f00a5d02903d368891..d233c1f8ec9639e68b979971d5d46621c6817b4b 100644 --- a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "lmdb.h" // NOLINT(build/include) namespace tensorflow { +namespace data { namespace { class LMDBDatasetOp : public DatasetOpKernel { @@ -212,4 +213,5 @@ class LMDBDatasetOp : public DatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index 725f8933c94cb42339556f63982d69d1bf0bb504..078de717e02bfdc9aa7e9842e27c660f7d174ce3 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +namespace data { namespace { struct BufferElement { @@ -1114,5 +1115,6 @@ REGISTER_KERNEL_BUILDER( Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU), MultiDeviceIteratorFromStringHandleOp); -} // anonymous namespace +} // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index ab584504a05369105d080df73750974af9fc70bb..30fa97a6363ce130b65edf5a4db2d5955b17beab 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { +namespace data { namespace { class ThreadPoolResource : public ResourceBase { @@ -214,4 +215,5 @@ REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU), ThreadPoolDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc index 6fbf5d2ebb598132a7e8433608e67436a172b615..57fc5697a44a2373b0aa97a1fb89917817827cac 100644 --- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -219,5 +219,5 @@ REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU), UniqueDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index ae104d55bd813fdbc9829ccbc274612a112c8e1d..ad410e17feb9de825aa3af07d4269161121a621a 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -65,7 +65,13 @@ REGISTER_OP("CSVDataset") TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); // `record_defaults` must be lists of scalars for (size_t i = 8; i < c->num_inputs(); ++i) { - TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused)); + shape_inference::ShapeHandle v; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); + if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { + return errors::InvalidArgument( + "Shape of a default must be a length-0 or length-1 vector, or a " + "scalar."); + } } return shape_inference::ScalarShape(c); }); diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 34f594f74194596292c7004295e6ecc2e4e125ec..ba202839b2f83b61256686b955c51bc0ae2cdace 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -72,12 +72,13 @@ py_test( "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:platform_test", "//tensorflow/python:session", "//tensorflow/python/data/ops:readers", + "//tensorflow/python/eager:context", "//third_party/py/numpy", ], ) @@ -276,25 +277,13 
@@ py_test( "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:data_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:function", + "//tensorflow/python:functional_ops", "//tensorflow/python:math_ops", - ], -) - -py_test( - name = "optimize_dataset_op_test", - size = "small", - srcs = ["optimize_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", + "//tensorflow/python:session", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 9d8e955245e0e3bc9c7635b801136c22bfc83488..8e368bf2bc5060e1655dd24b1d285b0ee80e094d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -57,7 +57,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for start in range(0, len(components), 4): @@ -85,7 +85,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for start in range(0, len(components), 4): @@ -123,7 +123,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Initialize with an input tensor of incompatible rank. 
sess.run(init_op, feed_dict={input_tensor: [[1]]}) with self.assertRaisesRegexp(errors.InvalidArgumentError, @@ -148,7 +148,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual((i,) * 3, sess.run(op)) @@ -168,7 +168,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op)) @@ -187,7 +187,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): st_row = sess.run(next_element) self.assertEqual([i], st_row.indices) @@ -208,7 +208,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): dense_elem, st_row = sess.run(next_element) self.assertEqual(i, dense_elem) @@ -230,7 +230,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(((i,),) * 3, sess.run(op)) @@ -250,7 +250,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() op = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")), sess.run(op)) @@ -266,7 +266,7 @@ class 
BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) @@ -284,7 +284,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = data.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # Mismatch in the 0th dimension. sess.run( iterator.initializer, @@ -319,7 +319,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_batch_size in [1, 3, 7, 10]: sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) num_batches = 7 // test_batch_size @@ -343,7 +343,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -374,7 +374,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_batch_size in [1, 3, 7, 10]: sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) num_batches = 7 // test_batch_size @@ -428,10 +428,10 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) @parameterized.named_parameters( - ("default", None, None), - ("sequential_calls", 1, None), - ("parallel_calls", 2, None), - ("parallel_batches", None, 10), + ("Default", None, None), + ("SequentialCalls", 1, None), + ("ParallelCalls", 2, None), + ("ParallelBatches", None, 10), ) def testMapAndBatch(self, 
num_parallel_calls, num_parallel_batches): """Test a dataset that maps a TF function across its input elements.""" @@ -461,7 +461,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: # Batch of a finite input, where the batch_size divides the # total number of elements. sess.run(init_op, feed_dict={count: 28, batch_size: 14}) @@ -505,8 +505,8 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) @parameterized.named_parameters( - ("even", False), - ("uneven", True), + ("Even", False), + ("Uneven", True), ) def testMapAndBatchPartialBatch(self, drop_remainder): iterator = ( @@ -520,7 +520,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): else: self.assertEqual([None, 1], iterator.output_shapes.as_list()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) if not drop_remainder: @@ -535,7 +535,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): .make_one_shot_iterator()) self.assertEqual([None, 1], iterator.output_shapes.as_list()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element)) self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element)) self.assertAllEqual([[64], [81]], sess.run(next_element)) @@ -549,7 +549,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): elements = [] for _ in range(100): elements.append(iterator.get_next()) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(5): got = 
sess.run(elements) got.sort(key=lambda x: x[0]) @@ -569,7 +569,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): elements = [] for _ in range(100): elements.append(iterator.get_next()) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(4): got = sess.run(elements) got.sort(key=lambda x: x[0]) @@ -591,7 +591,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(2): actual = sess.run(get_next) @@ -614,7 +614,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) .make_initializable_iterator()) init_op = iterator.initializer - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) @@ -635,7 +635,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaisesRegexp(errors.InvalidArgumentError, "number of elements does not match"): @@ -659,11 +659,18 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(3): sess.run(get_next) - @parameterized.parameters(0, 5, 10, 90, 95, 99) + @parameterized.named_parameters( + ("1", 0), + ("2", 5), + ("3", 10), + ("4", 90), + ("5", 95), + ("6", 99), + ) def testMapAndBatchOutOfRangeError(self, threshold): def raising_py_fn(i): @@ -679,7 +686,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): 
batch_size=10)).make_one_shot_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(threshold // 10): self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next)) if threshold % 10 != 0: @@ -689,18 +696,18 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @parameterized.parameters( - (False, dtypes.bool), - (-42, dtypes.int8), - (-42, dtypes.int16), - (-42, dtypes.int32), - (-42, dtypes.int64), - (42, dtypes.uint8), - (42, dtypes.uint16), - (42.0, dtypes.float16), - (42.0, dtypes.float32), - (42.0, dtypes.float64), - (b"hello", dtypes.string), + @parameterized.named_parameters( + ("1", False, dtypes.bool), + ("2", -42, dtypes.int8), + ("3", -42, dtypes.int16), + ("4", -42, dtypes.int32), + ("5", -42, dtypes.int64), + ("6", 42, dtypes.uint8), + ("7", 42, dtypes.uint16), + ("8", 42.0, dtypes.float16), + ("9", 42.0, dtypes.float32), + ("10", 42.0, dtypes.float64), + ("11", b"hello", dtypes.string), ) def testMapAndBatchTypes(self, element, dtype): def gen(): @@ -711,7 +718,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(10): self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) @@ -777,7 +784,7 @@ class RestructuredDatasetTest(test.TestCase): iterator = result.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(5): sess.run(get_next) @@ -901,7 +908,7 @@ class RestructuredDatasetTest(test.TestCase): .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: 
sess.run(init_op) with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 2022c1f2bdd09cdf43a993b3666335ce468a40ba..48971f2ccc4317d2bf591ae1e07cd6d5baf7b965 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -40,7 +40,7 @@ class GroupByReducerTest(test.TestCase): def checkResults(self, dataset, shapes, values): self.assertEqual(shapes, dataset.output_shapes) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for expected in values: got = sess.run(get_next) self.assertEqual(got, expected) @@ -129,7 +129,7 @@ class GroupByReducerTest(test.TestCase): self.assertIs(None, dataset.output_shapes[1].ndims) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: x, y = sess.run(get_next) self.assertAllEqual([0] * (2**i), x) self.assertAllEqual(np.array(1, ndmin=i), y) @@ -192,7 +192,7 @@ class GroupByReducerTest(test.TestCase): (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply( grouping.group_by_reducer(lambda x, y: np.int64(0), reducer)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: x, y = sess.run(get_next) self.assertAllEqual(x, np.asarray([x for x in range(10)])) self.assertEqual(y, 45) @@ -210,7 +210,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) counts = [] with self.assertRaises(errors.OutOfRangeError): @@ -237,7 +237,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer 
get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # The input is infinite, so this test demonstrates that: # 1. We produce output without having to consume the entire input, @@ -258,7 +258,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertAllEqual([0, 0, 0, 0], sess.run(get_next)) self.assertAllEqual([1, 1, 1, 1], sess.run(get_next)) @@ -275,7 +275,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaisesRegexp( errors.InvalidArgumentError, @@ -301,7 +301,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) @@ -329,7 +329,7 @@ class GroupByWindowTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) counts = [] with self.assertRaises(errors.OutOfRangeError): @@ -376,7 +376,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) which_bucket, bucketed_values = sess.run(get_next) @@ -411,7 +411,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Get two minibatches (one containing even values, one containing odds) @@ -482,7 +482,7 @@ class BucketTest(test.TestCase): init_op 
= iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Get two minibatches ([0, 2, ...] and [64, 66, ...]) @@ -515,7 +515,7 @@ class BucketTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) with self.assertRaises(errors.OutOfRangeError): batches = 0 @@ -531,6 +531,45 @@ class BucketTest(test.TestCase): self.assertEqual(batches, 15) +def _element_length_fn(x, y=None): + del y + return array_ops.shape(x)[0] + + +def _to_sparse_tensor(record): + return sparse_tensor.SparseTensor(**record) + + +def _format_record(array, sparse): + if sparse: + return { + "values": array, + "indices": [[i] for i in range(len(array))], + "dense_shape": (len(array),) + } + return array + + +def _get_record_type(sparse): + if sparse: + return { + "values": dtypes.int64, + "indices": dtypes.int64, + "dense_shape": dtypes.int64 + } + return dtypes.int32 + + +def _get_record_shape(sparse): + if sparse: + return { + "values": tensor_shape.TensorShape([None,]), + "indices": tensor_shape.TensorShape([None, 1]), + "dense_shape": tensor_shape.TensorShape([1,]) + } + return tensor_shape.TensorShape([None]) + + class BucketBySequenceLength(test.TestCase): def testBucket(self): @@ -539,39 +578,58 @@ class BucketBySequenceLength(test.TestCase): batch_sizes = [10, 8, 4, 2] lengths = [8, 13, 25, 35] - def element_gen(): - # Produce 1 batch for each bucket - elements = [] - for batch_size, length in zip(batch_sizes, lengths): - for _ in range(batch_size): - elements.append([1] * length) - random.shuffle(elements) - for el in elements: - yield (el,) - - element_len = lambda el: array_ops.shape(el)[0] - dataset = dataset_ops.Dataset.from_generator( - element_gen, (dtypes.int64,), ([None],)).apply( - grouping.bucket_by_sequence_length( - element_len, boundaries, batch_sizes)) - batch, = 
dataset.make_one_shot_iterator().get_next() - - with self.test_session() as sess: - batches = [] - for _ in range(4): - batches.append(sess.run(batch)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(batch) - batch_sizes_val = [] - lengths_val = [] - for batch in batches: - batch_size = batch.shape[0] - length = batch.shape[1] - batch_sizes_val.append(batch_size) - lengths_val.append(length) - self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) - self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) - self.assertEqual(sorted(lengths), sorted(lengths_val)) + def build_dataset(sparse): + def _generator(): + # Produce 1 batch for each bucket + elements = [] + for batch_size, length in zip(batch_sizes, lengths): + record_len = length - 1 + for _ in range(batch_size): + elements.append([1] * record_len) + record_len = length + random.shuffle(elements) + for el in elements: + yield (_format_record(el, sparse),) + dataset = dataset_ops.Dataset.from_generator( + _generator, + (_get_record_type(sparse),), + (_get_record_shape(sparse),)) + if sparse: + dataset = dataset.map(lambda x: (_to_sparse_tensor(x),)) + return dataset + + def _test_bucket_by_padding(no_padding): + dataset = build_dataset(sparse=no_padding) + dataset = dataset.apply( + grouping.bucket_by_sequence_length( + _element_length_fn, + boundaries, + batch_sizes, + no_padding=no_padding)) + batch, = dataset.make_one_shot_iterator().get_next() + + with self.cached_session() as sess: + batches = [] + for _ in range(4): + batches.append(sess.run(batch)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(batch) + batch_sizes_val = [] + lengths_val = [] + for batch in batches: + shape = batch.dense_shape if no_padding else batch.shape + batch_size = shape[0] + length = shape[1] + batch_sizes_val.append(batch_size) + lengths_val.append(length) + sum_check = batch.values.sum() if no_padding else batch.sum() + self.assertEqual(sum_check, batch_size * length - 1) + 
self.assertEqual(sum(batch_sizes_val), sum(batch_sizes)) + self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val)) + self.assertEqual(sorted(lengths), sorted(lengths_val)) + + for no_padding in (True, False): + _test_bucket_by_padding(no_padding) def testPadToBoundary(self): @@ -600,7 +658,7 @@ class BucketBySequenceLength(test.TestCase): pad_to_bucket_boundary=True)) batch, = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: batches = [] for _ in range(3): batches.append(sess.run(batch)) @@ -637,7 +695,7 @@ class BucketBySequenceLength(test.TestCase): pad_to_bucket_boundary=True)) batch, = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: batches = [] for _ in range(5): batches.append(sess.run(batch)) @@ -657,28 +715,108 @@ class BucketBySequenceLength(test.TestCase): def testTupleElements(self): - def elements_gen(): - text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] - label = [1, 2, 1, 2] - for x, y in zip(text, label): - yield (x, y) - - def element_length_fn(x, y): - del y - return array_ops.shape(x)[0] - - dataset = dataset_ops.Dataset.from_generator( - generator=elements_gen, - output_shapes=(tensor_shape.TensorShape([None]), - tensor_shape.TensorShape([])), - output_types=(dtypes.int32, dtypes.int32)) + def build_dataset(sparse): + def _generator(): + text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]] + label = [1, 2, 1, 2] + for x, y in zip(text, label): + yield (_format_record(x, sparse), y) + dataset = dataset_ops.Dataset.from_generator( + generator=_generator, + output_types=(_get_record_type(sparse), dtypes.int32), + output_shapes=(_get_record_shape(sparse), + tensor_shape.TensorShape([]))) + if sparse: + dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y)) + return dataset + + def _test_tuple_elements_by_padding(no_padding): + dataset = build_dataset(sparse=no_padding) + dataset = 
dataset.apply(grouping.bucket_by_sequence_length( + element_length_func=_element_length_fn, + bucket_batch_sizes=[2, 2, 2], + bucket_boundaries=[0, 8], + no_padding=no_padding)) + shapes = dataset.output_shapes + self.assertEqual([None, None], shapes[0].as_list()) + self.assertEqual([None], shapes[1].as_list()) + + for no_padding in (True, False): + _test_tuple_elements_by_padding(no_padding) + + def testBucketSparse(self): + """Tests bucketing of sparse tensors (case where `no_padding` == True). + + Test runs on following dataset: + [ + [0], + [0, 1], + [0, 1, 2] + ... + [0, ..., max_len - 1] + ] + Sequences are bucketed by length and batched with + `batch_size` < `bucket_size`. + """ + + min_len = 0 + max_len = 100 + batch_size = 7 + bucket_size = 10 + + def _build_dataset(): + input_data = [range(i+1) for i in range(min_len, max_len)] + def generator_fn(): + for record in input_data: + yield _format_record(record, sparse=True) + dataset = dataset_ops.Dataset.from_generator( + generator=generator_fn, + output_types=_get_record_type(sparse=True)) + dataset = dataset.map(_to_sparse_tensor) + return dataset + + def _compute_expected_batches(): + """Computes expected batch outputs and stores in a set.""" + all_expected_sparse_tensors = set() + for bucket_start_len in range(min_len, max_len, bucket_size): + for batch_offset in range(0, bucket_size, batch_size): + batch_start_len = bucket_start_len + batch_offset + batch_end_len = min(batch_start_len + batch_size, + bucket_start_len + bucket_size) + expected_indices = [] + expected_values = [] + for length in range(batch_start_len, batch_end_len): + for val in range(length + 1): + expected_indices.append((length - batch_start_len, val)) + expected_values.append(val) + expected_sprs_tensor = (tuple(expected_indices), + tuple(expected_values)) + all_expected_sparse_tensors.add(expected_sprs_tensor) + return all_expected_sparse_tensors + + def _compute_batches(dataset): + """Computes actual batch outputs of dataset and 
stores in a set.""" + batch = dataset.make_one_shot_iterator().get_next() + all_sparse_tensors = set() + with self.cached_session() as sess: + with self.assertRaises(errors.OutOfRangeError): + while True: + output = sess.run(batch) + sprs_tensor = (tuple([tuple(idx) for idx in output.indices]), + tuple(output.values)) + all_sparse_tensors.add(sprs_tensor) + return all_sparse_tensors + + dataset = _build_dataset() + boundaries = range(min_len + bucket_size + 1, max_len, bucket_size) dataset = dataset.apply(grouping.bucket_by_sequence_length( - element_length_func=element_length_fn, - bucket_batch_sizes=[2, 2, 2], - bucket_boundaries=[0, 8])) - shapes = dataset.output_shapes - self.assertEqual([None, None], shapes[0].as_list()) - self.assertEqual([None], shapes[1].as_list()) + _element_length_fn, + boundaries, + [batch_size] * (len(boundaries) + 1), + no_padding=True)) + batches = _compute_batches(dataset) + expected_batches = _compute_expected_batches() + self.assertEqual(batches, expected_batches) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 63bffd023f0e2672f41d36e27e31c9a9b26be77c..f8e74e4583df5b4e2cdd73c94361486680cee3f4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -31,38 +31,49 @@ from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.python.client import session from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import 
parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import googletest from tensorflow.python.platform import test +@test_util.run_all_in_graph_and_eager_modes class CsvDatasetOpTest(test.TestCase): - def _assert_datasets_equal(self, g, ds1, ds2): + def _get_next(self, dataset): + # Returns a no argument function whose result is fed to self.evaluate to + # yield the next element + it = dataset.make_one_shot_iterator() + if context.executing_eagerly(): + return it.get_next + else: + get_next = it.get_next() + return lambda: get_next + + def _assert_datasets_equal(self, ds1, ds2): assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, ' '%s') % (ds1.output_shapes, ds2.output_shapes) assert ds1.output_types == ds2.output_types assert ds1.output_classes == ds2.output_classes - next1 = ds1.make_one_shot_iterator().get_next() - next2 = ds2.make_one_shot_iterator().get_next() - with self.session(graph=g) as sess: - # Run through datasets and check that outputs match, or errors match. - while True: - try: - op1 = sess.run(next1) - except (errors.OutOfRangeError, ValueError) as e: - # If op1 throws an exception, check that op2 throws same exception. - with self.assertRaises(type(e)): - sess.run(next2) - break - op2 = sess.run(next2) - self.assertAllEqual(op1, op2) + next1 = self._get_next(ds1) + next2 = self._get_next(ds2) + # Run through datasets and check that outputs match, or errors match. + while True: + try: + op1 = self.evaluate(next1()) + except (errors.OutOfRangeError, ValueError) as e: + # If op1 throws an exception, check that op2 throws same exception. 
+ with self.assertRaises(type(e)): + self.evaluate(next2()) + break + op2 = self.evaluate(next2()) + self.assertAllEqual(op1, op2) def _setup_files(self, inputs, linebreak='\n', compression_type=None): filenames = [] @@ -95,33 +106,32 @@ class CsvDatasetOpTest(test.TestCase): def _test_by_comparison(self, inputs, **kwargs): """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv).""" - with ops.Graph().as_default() as g: - dataset_actual, dataset_expected = self._make_test_datasets( - inputs, **kwargs) - self._assert_datasets_equal(g, dataset_actual, dataset_expected) + dataset_actual, dataset_expected = self._make_test_datasets( + inputs, **kwargs) + self._assert_datasets_equal(dataset_actual, dataset_expected) def _verify_output_or_err(self, - sess, dataset, expected_output=None, expected_err_re=None): - nxt = dataset.make_one_shot_iterator().get_next() if expected_err_re is None: # Verify that output is expected, without errors + nxt = self._get_next(dataset) expected_output = [[ v.encode('utf-8') if isinstance(v, str) else v for v in op ] for op in expected_output] for value in expected_output: - op = sess.run(nxt) + op = self.evaluate(nxt()) self.assertAllEqual(op, value) with self.assertRaises(errors.OutOfRangeError): - sess.run(nxt) + self.evaluate(nxt()) else: # Verify that OpError is produced as expected with self.assertRaisesOpError(expected_err_re): + nxt = self._get_next(dataset) while True: try: - sess.run(nxt) + self.evaluate(nxt()) except errors.OutOfRangeError: break @@ -137,11 +147,8 @@ class CsvDatasetOpTest(test.TestCase): # Convert str type because py3 tf strings are bytestrings filenames = self._setup_files(inputs, linebreak, compression_type) kwargs['compression_type'] = compression_type - with ops.Graph().as_default() as g: - with self.session(graph=g) as sess: - dataset = readers.CsvDataset(filenames, **kwargs) - self._verify_output_or_err(sess, dataset, expected_output, - expected_err_re) + dataset = 
readers.CsvDataset(filenames, **kwargs) + self._verify_output_or_err(dataset, expected_output, expected_err_re) def testCsvDataset_requiredFields(self): record_defaults = [[]] * 4 @@ -191,21 +198,17 @@ class CsvDatasetOpTest(test.TestCase): record_defaults = [['']] * 3 inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] filenames = self._setup_files(inputs) - with ops.Graph().as_default() as g: - with self.session(graph=g) as sess: - dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) - dataset = dataset.apply(error_ops.ignore_errors()) - self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(dataset, [['e', 'f', 'g']]) def testCsvDataset_ignoreErrWithUnquotedQuotes(self): record_defaults = [['']] * 3 inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] filenames = self._setup_files(inputs) - with ops.Graph().as_default() as g: - with self.session(graph=g) as sess: - dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) - dataset = dataset.apply(error_ops.ignore_errors()) - self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(dataset, [['e', 'f', 'g']]) def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): record_defaults = [['']] * 3 @@ -351,10 +354,9 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['1,,3,4', '5,6,,8']] ds_actual, ds_expected = self._make_test_datasets( inputs, record_defaults=record_defaults) - with ops.Graph().as_default() as g: - self._assert_datasets_equal(g, - ds_actual.repeat(5).prefetch(1), - ds_expected.repeat(5).prefetch(1)) + self._assert_datasets_equal( + ds_actual.repeat(5).prefetch(1), + ds_expected.repeat(5).prefetch(1)) def 
testCsvDataset_withTypeDefaults(self): # Testing using dtypes as record_defaults for required fields @@ -373,13 +375,11 @@ class CsvDatasetOpTest(test.TestCase): ]] file_path = self._setup_files(data) - with ops.Graph().as_default() as g: - ds = readers.make_csv_dataset( - file_path, batch_size=1, shuffle=False, num_epochs=1) - next_batch = ds.make_one_shot_iterator().get_next() + ds = readers.make_csv_dataset( + file_path, batch_size=1, shuffle=False, num_epochs=1) + nxt = self._get_next(ds) - with self.session(graph=g) as sess: - result = list(sess.run(next_batch).values()) + result = list(self.evaluate(nxt()).values()) self.assertEqual(result, sorted(result)) @@ -542,6 +542,29 @@ class CsvDatasetOpTest(test.TestCase): compression_type='ZLIB', record_defaults=record_defaults) + def testCsvDataset_withScalarDefaults(self): + record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + + def testCsvDataset_with2DDefaults(self): + record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + + if context.executing_eagerly(): + err_spec = errors.InvalidArgumentError, ( + 'Each record default should be at ' + 'most rank 1.') + else: + err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2' + + with self.assertRaisesWithPredicateMatch(*err_spec): + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + class CsvDatasetBenchmark(test.Benchmark): """Benchmarks for the various ways of creating a dataset from CSV files. 
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py index 9020a499c4a5c35202a6f776d8795186b9c86e20..eb110324d12b47fc36bc0927ad8dc94e6892dc33 100644 --- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -38,7 +38,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for _ in range(100): for i in range(10): @@ -67,7 +67,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: freqs = np.zeros([num_datasets]) for _ in range(num_samples): freqs[sess.run(next_element)] += 1 @@ -104,7 +104,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in choice_array: self.assertEqual(words[i], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py index e6883d53e02c0f96d966a52abfe2f9b4118f2e12..f3968cdc15a5a34af0946c5c447ce35cdfa3e00d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -53,7 +53,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase): lambda x: (x * x, make_sparse(x))).take(take_t) element = 
get_single_element.get_single_element(dataset) - with self.test_session() as sess: + with self.cached_session() as sess: if error is None: dense_val, sparse_val = sess.run( element, feed_dict={ @@ -90,7 +90,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase): dataset = dataset_ops.Dataset.range(stop_t) element = get_single_element.reduce_dataset(dataset, sum_reducer) - with self.test_session() as sess: + with self.cached_session() as sess: value = sess.run(element, feed_dict={stop_t: stop}) self.assertEqual(stop * (stop - 1) / 2, value) diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py index db2ab815eeebb77c159ca8c7d0d9920f2bdcdabd..9c508d686dd44d04444a34c703ab54f3b97eeced 100644 --- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py @@ -44,14 +44,14 @@ class IndexedDatasetOpsTest(test.TestCase): get_op = gen_dataset_ops.indexed_dataset_get( handle, index, output_types=[dtypes.uint64], output_shapes=[[]]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(materialize) self.assertEqual([3], sess.run(get_op, feed_dict={index: 3})) def testIdentityIndexedDataset(self): ds = indexed_dataset_ops.IdentityIndexedDataset(16) materialized = ds.materialize() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(materialized.initializer) placeholder = array_ops.placeholder(dtypes.uint64, shape=[]) for i in range(16): @@ -66,7 +66,7 @@ class IndexedDatasetOpsTest(test.TestCase): ds = indexed_dataset_ops.IdentityIndexedDataset(16) itr = ds.make_initializable_iterator() n = itr.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(itr.initializer) for i in range(16): output = sess.run(n) diff --git 
a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 7a3215f6ccfa807e8930ac8561587e474da61195..b9e74dfddb1b238ab75928af17d545d2b6a3c033 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -177,7 +177,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0): # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and # `Dataset.flat_map()` and is single-threaded. No synchronization required. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -212,7 +212,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testSingleThreadedRagged(self): # Tests a sequence with wildly different elements per iterator. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -242,7 +242,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContention(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -286,7 +286,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): Args: sloppy: Whether to be sloppy or not. """ - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -328,7 +328,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContentionBlockLength(self, sloppy=False): # num_threads > 1. 
# Explicit coordination should result in `Dataset.interleave()` behavior - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -373,7 +373,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): Args: sloppy: Whether to be sloppy or not. """ - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -413,7 +413,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True) def _testEmptyInput(self, sloppy=False): - with self.test_session() as sess: + with self.cached_session() as sess: # Empty input. self._clear_coordination_events() sess.run( @@ -437,7 +437,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False): # Non-empty input leading to empty output. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -461,7 +461,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1): race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds # Mixture of non-empty and empty interleaved datasets. - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -500,7 +500,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testDelayedOutputSloppy(self): # Explicitly control the sequence of events to ensure we correctly avoid # head-of-line blocking. 
- with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -525,7 +525,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(self.next_element) def testBlockLengthWithContentionSloppy(self): - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() done_first_event = False sess.run( @@ -560,7 +560,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testEarlyExit(self, sloppy=False): # Exiting without consuming all input should not block - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -604,7 +604,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy)) iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: + with self.cached_session() as sess: output_values = [] for _ in range(30): output_values.append(sess.run(iterator.get_next())) @@ -635,7 +635,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for i in range(10): for j in range(2): @@ -645,7 +645,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(get_next) def testErrorsInOutputFn(self): - with self.test_session() as sess: + with self.cached_session() as sess: self._clear_coordination_events() sess.run( self.init_op, @@ -704,7 +704,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.init_op = self.iterator.initializer self.next_element = self.iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={ @@ -753,7 +753,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.init_op = self.iterator.initializer self.next_element = 
self.iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.init_op, feed_dict={ @@ -792,7 +792,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): next_element = iterator.get_next() results = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(2): elements = [] sess.run(iterator.initializer) diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index 7bc582ebaa50c7418e7624a1a389f002f2cea395..1cc5ddc9a2e1eff4473c19bc397d656e7e0b90ed 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -51,7 +51,7 @@ class LMDBDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for _ in range(num_repeats): # Dataset is repeated. for i in range(10): # 10 records. 
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index dc9d56dd53cc077c14eda58a22d7449c05bddec1..e8519381d69427f4c9a3ef5cefa527c368251f2a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -54,7 +54,7 @@ class MapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for x in [1., 2., 3., 5.]: self.assertEqual(x, sess.run(get_next)) @@ -72,7 +72,7 @@ class MapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for x in [1., 2., 3., 5.]: self.assertEqual(x, sess.run(get_next)) @@ -99,7 +99,7 @@ class MapDatasetTest(test.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: # All of the files are present. 
sess.run(init_op) for filename in filenames: @@ -209,7 +209,7 @@ class MapDatasetBenchmark(test.Benchmark): end = time.time() chained_deltas.append(end - start) - fused_dataset = dataset = dataset.apply( + fused_dataset = dataset.apply( batching.map_and_batch( math_ops.matmul, num_parallel_calls=num_calls, diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py index 73cde40305a676e114a722bf8b4702e152346c8b..83b723710ca1d37a8d2b1e297321b59dcaa17ba6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -17,7 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + from tensorflow.contrib.data.python.ops import map_defun +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -25,10 +28,11 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test - class MapDefunTest(test.TestCase): def testMapDefunSimple(self): @@ -130,6 +134,146 @@ class MapDefunTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(result) + def testMapDefunCancelledCorrectly(self): + + @function.Defun(dtypes.int64) + def defun(x): + # x has leading dimension 5, this will raise an error + return array_ops.gather(x, 10) + + c = array_ops.tile( + array_ops.expand_dims( + constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0), + [100, 1]) + map_defun_op = 
map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0] + with self.assertRaisesRegexp(errors.InvalidArgumentError, + r"indices = 10 is not in \[0, 5\)"): + self.evaluate(map_defun_op) + + def testMapDefunWithUnspecifiedOutputShape(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + res = x * 2 + 3 + return (res, res + 1, res + 2) + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], + [dtypes.int32, dtypes.int32, dtypes.int32], + [None, (None,), (2,)]) + expected = elems * 2 + 3 + self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected)) + self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1)) + self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2)) + + def testMapDefunWithDifferentOutputShapeEachRun(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + elems = array_ops.placeholder(dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0] + with session.Session() as sess: + self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3]) + self.assertAllEqual( + sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]]) + + def testMapDefunWithWrongOutputShape(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0] + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(r) + + def testMapDefunWithInvalidInput(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + + c = constant_op.constant(2) + with self.assertRaises(ValueError): + # Fails at graph construction time for inputs with known shapes. 
+ r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0] + p = array_ops.placeholder(dtypes.int32) + r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0] + with session.Session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(r, feed_dict={p: 0}) + + def _assert_op_cancelled(self, sess, map_defun_op): + with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"): + sess.run(map_defun_op) + + def testMapDefunWithParentCancellation(self): + # Checks that a cancellation of the parent graph is threaded through to + # MapDefunOp correctly. + @function.Defun(dtypes.int32) + def simple_fn(x): + del x + queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ()) + # Blocking + return queue.dequeue_many(5) + + c = constant_op.constant([1, 2, 3, 4, 5]) + map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0] + + with self.test_session() as sess: + thread = self.checkedThread( + self._assert_op_cancelled, args=(sess, map_defun_op)) + thread.start() + time.sleep(0.1) + sess.close() + thread.join() + + +class MapDefunBenchmark(test.Benchmark): + + def _run(self, op, name=None, num_iters=3000): + with session.Session() as sess: + # Warm up the session + for _ in range(5): + sess.run(op) + start = time.time() + for _ in range(num_iters): + sess.run(op) + end = time.time() + mean_us = (end - start) * 1e6 / num_iters + self.report_benchmark( + name=name, + iters=num_iters, + wall_time=mean_us, + extras={"examples_per_sec": num_iters / (end - start)}) + + def benchmarkDefunVsMapFn(self): + """Benchmarks to compare the performance of MapDefun vs tf.map_fn.""" + + @function.Defun(dtypes.int32) + def defun(x): + return array_ops.identity(x) + + def map_fn(x): + return array_ops.identity(x) + + base = math_ops.range(100) + for input_size in [10, 100, 1000, 10000]: + num_iters = 100000 // input_size + map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()]) + map_fn_op = 
functional_ops.map_fn(map_fn, base) + + self._run( + map_defun_op, + "benchmarkMapDefun_size_%d" % input_size, + num_iters=num_iters) + self._run( + map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD index b299e0736fb29d0936680e5905172b0fa95ac586..7e9ea68047a076d368cf98960f4754b29abb074e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD @@ -6,6 +6,34 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") +py_test( + name = "assert_next_dataset_op_test", + size = "medium", + srcs = ["assert_next_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "latency_all_edges_test", + size = "small", + srcs = ["latency_all_edges_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "map_vectorization_test", size = "small", @@ -46,16 +74,34 @@ py_test( ) py_test( - name = "latency_all_edges_test", + name = "model_dataset_op_test", + size = "medium", + srcs = ["model_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = [ + "optonly", + ], + deps = [ + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/contrib/data/python/ops:optimization", + 
"//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "optimize_dataset_op_test", size = "small", - srcs = ["latency_all_edges_test.py"], + srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/contrib/data/python/ops:stats_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7b50b902cf5965bfdb586c5c9fce68ba5d9cd6 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class AssertNextDatasetTest(test.TestCase): + + def testAssertNext(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(0, sess.run(get_next)) + + def testAssertNextInvalid(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted Whoops transformation at offset 0 but encountered " + "Map transformation instead."): + sess.run(get_next) + + def testAssertNextShort(self): + dataset = dataset_ops.Dataset.from_tensors(0).apply( + optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Asserted next 2 transformations but encountered only 1."): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py index 1850b6921af0aae8d26fbdfd165fd0e087134e6d..db380c02a9191bec53d5e32565d47a52cbdd44b1 100644 --- 
a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py @@ -40,7 +40,7 @@ class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): get_next = iterator.get_next() summary_t = stats_aggregator.get_summary() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertEqual(1 * 1, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py index 586b4bee5fcb1d8de44e8bc5e78cc21e15870a5c..dde115925ee484edb88ad81b21595c3d668be84c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py @@ -44,22 +44,22 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): for i, fun1 in enumerate(functions): for j, fun2 in enumerate(functions): tests.append(( - "test_{}_{}".format(i, j), + "Test{}{}".format(i, j), [fun1, fun2], )) for k, fun3 in enumerate(functions): tests.append(( - "test_{}_{}_{}".format(i, j, k), + "Test{}{}{}".format(i, j, k), [fun1, fun2, fun3], )) swap = lambda x, n: (n, x) tests.append(( - "swap1", + "Swap1", [lambda x: (x, 42), swap], )) tests.append(( - "swap2", + "Swap2", [lambda x: (x, 42), swap, swap], )) return tuple(tests) @@ -74,7 +74,7 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"])) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for x in range(5): result = sess.run(get_next) r = x @@ -109,13 +109,13 @@ class 
MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): for x, fun in enumerate(functions): for y, predicate in enumerate(filters): - tests.append(("mixed_{}_{}".format(x, y), fun, predicate)) + tests.append(("Mixed{}{}".format(x, y), fun, predicate)) # Multi output - tests.append(("multiOne", lambda x: (x, x), + tests.append(("Multi1", lambda x: (x, x), lambda x, y: constant_op.constant(True))) tests.append( - ("multiTwo", lambda x: (x, 2), + ("Multi2", lambda x: (x, 2), lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))) return tuple(tests) @@ -131,7 +131,7 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): def _testMapAndFilter(self, dataset, function, predicate): iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for x in range(10): r = function(x) if isinstance(r, tuple): @@ -172,17 +172,17 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): identity = lambda x: x for x, predicate_1 in enumerate(filters): for y, predicate_2 in enumerate(filters): - tests.append(("mixed_{}_{}".format(x, y), identity, + tests.append(("Mixed{}{}".format(x, y), identity, [predicate_1, predicate_2])) for z, predicate_3 in enumerate(filters): - tests.append(("mixed_{}_{}_{}".format(x, y, z), identity, + tests.append(("Mixed{}{}{}".format(x, y, z), identity, [predicate_1, predicate_2, predicate_3])) take_all_multiple = lambda x, y: constant_op.constant(True) # Multi output - tests.append(("multiOne", lambda x: (x, x), + tests.append(("Multi1", lambda x: (x, x), [take_all_multiple, take_all_multiple])) - tests.append(("multiTwo", lambda x: (x, 2), [ + tests.append(("Multi2", lambda x: (x, 2), [ take_all_multiple, lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0) ])) diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py 
b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0a87d3e90550da8485b4f9acd941c836d7b62951 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py @@ -0,0 +1,177 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ModelDatasetTest(test.TestCase): + + def testModelMap(self): + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), + np.random.rand(4 * k, + 1))).repeat() + dataset = dataset.map(math_ops.matmul) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.test_session() as sess: + for _ in range(5): + sess.run(get_next.op) + for _ in range(100): + start = time.time() + sess.run(get_next.op) + end 
= time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + def testModelParallelMap(self): + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), + np.random.rand(4 * k, + 1))).repeat() + dataset = dataset.map(math_ops.matmul, num_parallel_calls=56) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.test_session() as sess: + for _ in range(5): + sess.run(get_next.op) + for _ in range(1000): + start = time.time() + sess.run(get_next.op) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + def testModelMapAndBatch(self): + batch_size = 16 + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), + np.random.rand(4 * k, + 1))).repeat() + dataset = dataset.apply( + batching.map_and_batch( + math_ops.matmul, num_parallel_calls=28, batch_size=batch_size)) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.test_session() as sess: + for _ in range(5): + sess.run(get_next.op) + for _ in range(10): + start = time.time() + sess.run(get_next.op) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + def testModelParallelInterleave(self): + k = 1024 * 1024 + dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k), + np.random.rand(4 * k, + 1))).repeat() + dataset = dataset.map(math_ops.matmul) + dataset = dataset_ops.Dataset.range(1).repeat().interleave( + lambda _: 
dataset, cycle_length=56, num_parallel_calls=56) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.test_session() as sess: + for _ in range(5): + sess.run(get_next.op) + for _ in range(1000): + start = time.time() + sess.run(get_next.op) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + def testModelNested(self): + k = 1024 * 1024 + a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1)) + b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1)) + c = (np.random.rand(1, 2 * k), np.random.rand(2 * k, 1)) + dataset = dataset_ops.Dataset.from_tensors((a, b, c)).repeat() + + def f1(a, b, c): + x, y = a + return math_ops.matmul(x, y), b, c + + def f2(a, b, c): + x, y = b + return a, math_ops.matmul(x, y), c + + def f3(a, b, c): + x, y = c + return a, b, math_ops.matmul(x, y) + + dataset = dataset.map(f1, num_parallel_calls=32) + dataset = dataset_ops.Dataset.range(1).repeat().interleave( + lambda _: dataset, cycle_length=2) + + dataset = dataset.map(f2, num_parallel_calls=16) + dataset = dataset_ops.Dataset.range(1).repeat().interleave( + lambda _: dataset, cycle_length=2) + + dataset = dataset.map(f3, num_parallel_calls=10) + iterator = dataset.apply(optimization.model()).make_one_shot_iterator() + get_next = iterator.get_next() + + deltas = [] + with self.test_session() as sess: + for _ in range(5): + sess.run(get_next) + for _ in range(100): + start = time.time() + sess.run(get_next) + end = time.time() + deltas.append(end - start) + + print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" % + (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas), + np.max(deltas))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py 
b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py similarity index 75% rename from tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py index 089717156c545a0ea9262c4380ab2c0fd088e209..909da5aee0ad8bce0b5b18facbcbb684dd334abf 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.ops import optimization @@ -29,41 +28,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): - - def testAssertSuffix(self): - dataset = dataset_ops.Dataset.from_tensors(0).apply( - optimization.assert_next(["Map"])).map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - self.assertEqual(0, sess.run(get_next)) - - def testAssertSuffixInvalid(self): - dataset = dataset_ops.Dataset.from_tensors(0).apply( - optimization.assert_next(["Whoops"])).map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Asserted Whoops transformation at offset 0 but encountered " - "Map transformation instead."): - sess.run(get_next) - - def testAssertSuffixShort(self): - dataset = dataset_ops.Dataset.from_tensors(0).apply( - optimization.assert_next(["Map", "Whoops"])).map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - 
with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Asserted next 2 transformations but encountered only 1."): - sess.run(get_next) +class OptimizeDatasetTest(test.TestCase): def testOptimizationDefault(self): dataset = dataset_ops.Dataset.range(10).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py index f6c4a984b8608b408bc1b1bb4a712ef1c3792696..c4623bca73228b76802ed40b18eb49662f6f7d34 100644 --- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py @@ -80,7 +80,7 @@ class ParseExampleTest(test.TestCase): expected_values=None, expected_err=None): - with self.test_session() as sess: + with self.cached_session() as sess: if expected_err: with self.assertRaisesWithPredicateMatch(expected_err[0], expected_err[1]): diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index 361fe0dd39bb3f855c3b0b11281a9909fd601232..0166ba0d44ef473ac54ee4f67078c1a51fddacf3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -235,7 +235,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): destroy_op = resource_variable_ops.destroy_resource_op( buffer_resource_handle, ignore_lookup_error=True) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual([b"a"], sess.run(prefetch_op)) self.assertEqual([b"b"], sess.run(prefetch_op)) self.assertEqual([b"c"], sess.run(prefetch_op)) @@ -301,7 +301,7 @@ class PrefetchToDeviceTest(test.TestCase): self.assertEqual(dtypes.int64, next_element.dtype) self.assertEqual([], next_element.shape) - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(i, sess.run(next_element)) with 
self.assertRaises(errors.OutOfRangeError): @@ -384,7 +384,7 @@ class PrefetchToDeviceTest(test.TestCase): iterator = device_dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(i, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -435,7 +435,7 @@ class PrefetchToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(5): self.assertEqual(i, sess.run(next_element)) @@ -683,7 +683,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -702,7 +702,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -721,7 +721,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([0, 1, 2, 3], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -739,7 +739,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([0, 1, 2, 3], 
sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -757,7 +757,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -775,7 +775,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -796,7 +796,7 @@ class CopyToDeviceTest(test.TestCase): iterator = back_to_cpu_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(10): self.assertEqual(i, sess.run(next_element)) @@ -875,7 +875,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(5): self.assertEqual(i, sess.run(next_element)) @@ -897,7 +897,7 @@ class CopyToDeviceTest(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for i in range(5): self.assertEqual(i, sess.run(next_element)) @@ -920,7 +920,7 @@ class CopyToDeviceTest(test.TestCase): elem_has_value_t = next_elem.has_value() elem_value_t = next_elem.get_value() - with self.test_session() as sess: + with self.cached_session() as sess: # Before initializing 
the iterator, evaluating the optional fails with # a FailedPreconditionError. with self.assertRaises(errors.FailedPreconditionError): diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 592642da0cfd84e50cb20d9b2e534411faf927e8..db8fe6aa1b29c5c3f872e580491d978f03360fe4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -43,7 +43,7 @@ class RangeDatasetTest(test.TestCase): self.assertEqual([tensor_shape.TensorShape([])] * 3, [t.shape for t in get_next[1]]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next)) self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next)) @@ -63,7 +63,7 @@ class RangeDatasetTest(test.TestCase): .make_one_shot_iterator()) negative_get_next = negative_iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(3, sess.run(get_next)) self.assertEqual(3 + 4, sess.run(get_next)) self.assertEqual(3 + 2 * 4, sess.run(get_next)) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index fd00cdc5c61cb0a6bbee87963ed4097a236507d3..ed75b27a4493f9ebb9db34c4e656d394236ae08e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -116,7 +116,7 @@ class ReadBatchFeaturesTest( init_op = iterator.initializer next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) for file_batch, _, _, _, record_batch, _ in self._next_expected_batch( range(self._num_files), 2, 10): diff --git 
a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index c5cfddb72b56a1bcffc80c0dd34994def3ee45cd..16b1441baab925ed5b6eee4193203690d1552f03 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -77,7 +77,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase): class_func=lambda c, _: c, seed=27)).make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: returned = [] while len(returned) < 4000: returned.append(sess.run(get_next)) @@ -115,7 +115,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: returned = [] with self.assertRaises(errors.OutOfRangeError): while True: @@ -146,7 +146,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase): get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: returned = [] with self.assertRaises(errors.OutOfRangeError): while True: diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index 42cada0b97bcd9ab755297e8b1f0667766f7999e..dde678bd544fc2eaba36f91491fc64e4c7910756 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -50,7 +50,7 @@ class ScanDatasetTest(test.TestCase): start, make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), (10, 2, 10), (10, -1, 10), @@ -100,7 +100,7 @@ class 
ScanDatasetTest(test.TestCase): make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), (10, 2, 10), (10, -1, 10), @@ -133,7 +133,7 @@ class ScanDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(5): (longer_vector_val, larger_rank_val), _ = sess.run(next_element) self.assertAllEqual([0] * (2**i), longer_vector_val) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD index 4881f63ab96cb4797e6e071bf3e310c73bc85f3d..aa89674c6e74686feb5b3a9331eb1839a241c4be 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -210,6 +210,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py index ac3892fe81a1c0d325ddc5f501c2caed4b53f5d5..243f6405a13b96ce2bb1a2c924ce6b0d742ca8c0 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.kernel_tests.serialization import 
dataset_serialization_test_base @@ -27,42 +28,38 @@ from tensorflow.python.platform import test class InterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + dataset_serialization_test_base.DatasetSerializationTestBase, + parameterized.TestCase): - def _build_iterator_graph(self, input_values, cycle_length, block_length): + def _build_iterator_graph(self, input_values, cycle_length, block_length, + num_parallel_calls): repeat_count = 2 return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( repeat_count).interleave( lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length) + cycle_length, block_length, num_parallel_calls) - def testSerializationCore(self): + @parameterized.named_parameters( + ("1", 2, 3, None), + ("2", 2, 3, 1), + ("3", 2, 3, 2), + ("4", 1, 3, None), + ("5", 1, 3, 1), + ("6", 2, 1, None), + ("7", 2, 1, 1), + ("8", 2, 1, 2), + ) + def testSerializationCore(self, cycle_length, block_length, + num_parallel_calls): input_values = np.array([4, 5, 6], dtype=np.int64) num_outputs = np.sum(input_values) * 2 - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 # pylint: disable=g-long-lambda self.run_core_tests( lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), + input_values, cycle_length, block_length, num_parallel_calls), lambda: self._build_iterator_graph( - input_values, cycle_length * 2, block_length * 1), + input_values, cycle_length * 2, block_length, num_parallel_calls), num_outputs) - # cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) # pylint: enable=g-long-lambda def testSparseCore(self): @@ 
-82,5 +79,5 @@ class InterleaveDatasetSerializationTest( self.run_core_tests(_build_dataset, None, 20) -if __name__ == '__main__': +if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 077abd6b30eafe857d27d84e533b15e4e98134e6..440e48db3095fe7006d510f7db80ad5327284659 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -35,7 +35,7 @@ class ShuffleAndRepeatTest(test.TestCase): def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True): get_next = ds_fn().make_one_shot_iterator().get_next() outputs = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(num_outputs): outputs.append(sess.run(get_next)) if verify_exhausted: diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 8b2f84649486e35e1067f5f9cbe4a7abec71e080..90d18dca2aa727ea51d636cb971f48b50bc0c663 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -32,18 +32,18 @@ from tensorflow.python.platform import test class SlideDatasetTest(test.TestCase, parameterized.TestCase): - @parameterized.parameters( - (20, 14, 7, 1), - (20, 17, 9, 1), - (20, 14, 14, 1), - (20, 10, 14, 1), - (20, 14, 19, 1), - (20, 4, 1, 2), - (20, 2, 1, 6), - (20, 4, 7, 2), - (20, 2, 7, 6), - (1, 10, 4, 1), - (0, 10, 4, 1), + @parameterized.named_parameters( + ("1", 20, 14, 7, 1), + ("2", 20, 17, 9, 1), + ("3", 20, 14, 14, 1), + ("4", 20, 10, 14, 1), + ("5", 20, 14, 19, 1), + ("6", 20, 4, 1, 2), + ("7", 20, 2, 1, 6), + ("8", 20, 4, 7, 2), + ("9", 20, 2, 7, 6), + ("10", 1, 10, 4, 1), + ("11", 0, 10, 4, 1), ) def testSlideDataset(self, 
count, window_size, window_shift, window_stride): """Tests a dataset that slides a window its input elements.""" @@ -75,7 +75,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -96,18 +96,18 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @parameterized.parameters( - (20, 14, 7, 1), - (20, 17, 9, 1), - (20, 14, 14, 1), - (20, 10, 14, 1), - (20, 14, 19, 1), - (20, 4, 1, 2), - (20, 2, 1, 6), - (20, 4, 7, 2), - (20, 2, 7, 6), - (1, 10, 4, 1), - (0, 10, 4, 1), + @parameterized.named_parameters( + ("1", 20, 14, 7, 1), + ("2", 20, 17, 9, 1), + ("3", 20, 14, 14, 1), + ("4", 20, 10, 14, 1), + ("5", 20, 14, 19, 1), + ("6", 20, 4, 1, 2), + ("7", 20, 2, 1, 6), + ("8", 20, 4, 7, 2), + ("9", 20, 2, 7, 6), + ("10", 1, 10, 4, 1), + ("11", 0, 10, 4, 1), ) def testSlideDatasetDeprecated(self, count, window_size, stride, window_stride): @@ -139,7 +139,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([[None] + list(c.shape[1:]) for c in components], [t.shape.as_list() for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -160,10 +160,10 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @parameterized.parameters( - (14, 0, 3, 1), - (14, 3, 0, 1), - (14, 3, 3, 0), + @parameterized.named_parameters( + ("1", 14, 0, 3, 1), + ("2", 14, 3, 0, 1), + ("3", 14, 3, 3, 0), ) def testSlideDatasetInvalid(self, count, window_size, window_shift, window_stride): @@ -180,7 +180,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): 
window_stride=window_stride_t)).make_initializable_iterator()) init_op = iterator.initializer - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run( init_op, @@ -214,7 +214,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) num_batches = (10 - 5) // 3 + 1 for i in range(num_batches): @@ -243,7 +243,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) num_batches = (10 - 5) // 3 + 1 for i in range(num_batches): @@ -277,7 +277,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op) # Slide: 1st batch. 
actual = sess.run(get_next) @@ -316,7 +316,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): .make_initializable_iterator()) next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) with self.assertRaisesRegexp( errors.InvalidArgumentError, diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index 2c2cfbebff5d3eba00f120467102b4185d81ab24..52823d3fcace841ff0a68b8036c4e357f7c3c7b4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -30,7 +30,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string), 2) - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(2): # Run twice to verify statelessness of db operations. 
sess.run( init_op, @@ -48,7 +48,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetJoinQuery(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -67,7 +67,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetNullTerminator(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -86,7 +86,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetReuseSqlDataset(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -114,7 +114,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadEmptyResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -128,7 +128,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetWithInvalidDriverName(self): init_op = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string))[0] - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run( init_op, @@ -142,7 +142,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetWithInvalidColumnName(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ 
@@ -157,7 +157,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetOfQueryWithSyntaxError(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -173,7 +173,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -190,7 +190,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetOfInsertQuery(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.string)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -205,7 +205,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in an `int8` tensor. def testReadResultSetInt8(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -222,7 +222,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetInt8NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8, dtypes.int8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -238,7 +238,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int8` tensor. 
def testReadResultSetInt8MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -256,7 +256,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in an `int16` tensor. def testReadResultSetInt16(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -273,7 +273,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetInt16NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16, dtypes.int16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -289,7 +289,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int16` tensor. def testReadResultSetInt16MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -307,7 +307,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in an `int32` tensor. def testReadResultSetInt32(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -321,7 +321,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place it in an `int32` tensor. 
def testReadResultSetInt32NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -337,7 +337,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int32` tensor. def testReadResultSetInt32MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -355,7 +355,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # table and place it in an `int32` tensor. def testReadResultSetInt32VarCharColumnAsInt(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -371,7 +371,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # and place it in an `int64` tensor. def testReadResultSetInt64(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -387,7 +387,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place it in an `int64` tensor. def testReadResultSetInt64NegativeAndZero(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -403,7 +403,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # a SQLite database table and place it in an `int64` tensor. 
def testReadResultSetInt64MaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -422,7 +422,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # place it in a `uint8` tensor. def testReadResultSetUInt8(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -438,7 +438,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place them in `uint8` tensors. def testReadResultSetUInt8MinAndMaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -456,7 +456,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # and place it in a `uint16` tensor. def testReadResultSetUInt16(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -472,7 +472,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # SQLite database table and place them in `uint16` tensors. def testReadResultSetUInt16MinAndMaxValues(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -491,7 +491,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # in `bool` tensors. 
def testReadResultSetBool(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -508,7 +508,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # from a SQLite database table and place it as `True` in a `bool` tensor. def testReadResultSetBoolNotZeroOrOne(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -525,7 +525,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetFloat64(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.float64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -544,7 +544,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetFloat64OverlyPrecise(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.float64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ @@ -570,7 +570,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, dtypes.float64)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( init_op, feed_dict={ diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 43067b4245d879aef9a40dc546b2a7742b3dc09c..e25570c5ad1e913c67c3c4339b3bdaf0523ccb04 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ 
b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -75,6 +75,31 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): sess.run(next_element) self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) + def testPrefetchBufferUtilization(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch( + -1).apply(stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(100): + self.assertAllEqual( + np.array([i] * i, dtype=np.int64), sess.run(next_element)) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization", + float(i + 1)) + self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization", + 0, 1) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization", + 100) + def testReinitialize(self): stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py index 9a13acf8f0ac6690cad8847873768562da795496..2f5a44408fab5a686e5621660e7e3aca3e36954a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py @@ -34,6 +34,16 @@ class StatsDatasetTestBase(test.TestCase): return self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value): + 
summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertLessEqual(min_value, value.histo.min) + self.assertGreaterEqual(max_value, value.histo.max) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + def _assertSummaryHasSum(self, summary_str, tag, expected_value): summary_proto = summary_pb2.Summary() summary_proto.ParseFromString(summary_str) diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py index 1d70b16041e902a5d08383887cbf647eac2e816c..4c3353fe4046d6b2bfabac580b46f88c8d7f2941 100644 --- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py +++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py @@ -31,7 +31,7 @@ class DatasetTestBase(test.TestCase): # TODO(rachelim): support sparse tensor outputs next1 = dataset1.make_one_shot_iterator().get_next() next2 = dataset2.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: while True: try: op1 = sess.run(next1) @@ -52,9 +52,12 @@ class DatasetTestBase(test.TestCase): dataset2, exception_class, replacements=None): - next1 = dataset1.make_one_shot_iterator().get_next() - next2 = dataset2.make_one_shot_iterator().get_next() - with self.test_session() as sess: + # We are defining next1 and next2 in the same line so that we get identical + # file:line_number in the error messages + # pylint: disable=line-too-long + next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next() + # pylint: enable=line-too-long + with self.cached_session() as sess: try: sess.run(next1) raise ValueError( diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py index 
0486e2bce20e9dcf81dcb5ac49fe5b397e44bf0c..8d335e87d549426f275768e874e4cbf466ed5dcc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -33,8 +33,17 @@ from tensorflow.python.platform import test class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): - @parameterized.parameters((1, None), (2, None), (4, None), (8, None), - (16, None), (4, -1), (4, 0), (4, 1), (4, 4)) + @parameterized.named_parameters( + ("1", 1, None), + ("2", 2, None), + ("3", 4, None), + ("4", 8, None), + ("5", 16, None), + ("6", 4, -1), + ("7", 4, 0), + ("8", 4, 1), + ("9", 4, 4), + ) def testNumThreads(self, num_threads, max_intra_op_parallelism): def get_thread_id(_): @@ -60,7 +69,7 @@ class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) thread_ids = [] try: diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py index d79a842e7a5d816e2e6a52fc83acbd6b260cf64b..f994c8563f6173a7d8943aaedc854a53e16dad24 100644 --- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -45,7 +45,7 @@ class UniqueDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for test_case, expected in test_cases: current_test_case = test_case sess.run(iterator.initializer) diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py index 
33d95d67549e1c8d1d9af578fcebbb4f939c418a..6eaa0b195911acb057b30b8ca7408cdbfdce8352 100644 --- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py @@ -64,15 +64,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): else: self.assertEqual(xs, ys) - @parameterized.parameters( - (None, np.int32([]), dtypes.bool), - (None, np.int32([]), dtypes.int32), - (None, np.int32([]), dtypes.float32), - (None, np.int32([]), dtypes.string), - (None, np.int32([2]), dtypes.int32), - (None, np.int32([2, 2]), dtypes.int32), - ((None, None, None), np.int32([]), dtypes.int32), - ((None, (None, None)), np.int32([]), dtypes.int32), + @parameterized.named_parameters( + ("1", None, np.int32([]), dtypes.bool), + ("2", None, np.int32([]), dtypes.int32), + ("3", None, np.int32([]), dtypes.float32), + ("4", None, np.int32([]), dtypes.string), + ("5", None, np.int32([2]), dtypes.int32), + ("6", None, np.int32([2, 2]), dtypes.int32), + ("7", (None, None, None), np.int32([]), dtypes.int32), + ("8", (None, (None, None)), np.int32([]), dtypes.int32), ) def testWindowDatasetFlatMap(self, structure, shape, dtype): """Tests windowing by chaining it with flat map. 
@@ -92,20 +92,20 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): dataset = self._structuredDataset(structure, shape, dtype).apply( grouping.window_dataset(5)).flat_map(fn) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run(self._structuredElement(structure, shape, dtype)) actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (None, np.int32([]), dtypes.bool), - (None, np.int32([]), dtypes.int32), - (None, np.int32([]), dtypes.float32), - (None, np.int32([]), dtypes.string), - (None, np.int32([2]), dtypes.int32), - (None, np.int32([2, 2]), dtypes.int32), - ((None, None, None), np.int32([]), dtypes.int32), - ((None, (None, None)), np.int32([]), dtypes.int32), + @parameterized.named_parameters( + ("1", None, np.int32([]), dtypes.bool), + ("2", None, np.int32([]), dtypes.int32), + ("3", None, np.int32([]), dtypes.float32), + ("4", None, np.int32([]), dtypes.string), + ("5", None, np.int32([2]), dtypes.int32), + ("6", None, np.int32([2, 2]), dtypes.int32), + ("7", (None, None, None), np.int32([]), dtypes.int32), + ("8", (None, (None, None)), np.int32([]), dtypes.int32), ) def testWindowDatasetBatchDense(self, structure, shape, dtype): """Tests batching of dense tensor windows. 
@@ -128,17 +128,17 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply( grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run( self._structuredElement(structure, np.concatenate( ([5], shape), axis=0), dtype)) actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int32([]),), - (np.int32([1]),), - (np.int32([1, 2, 3]),), + @parameterized.named_parameters( + ("1", np.int32([])), + ("2", np.int32([1])), + ("3", np.int32([1, 2, 3])), ) def testWindowDatasetBatchDenseDynamicShape(self, shape): """Tests batching of dynamically shaped dense tensor windows. @@ -155,7 +155,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shape_t: shape}) expected = sess.run( self._structuredElement(None, np.concatenate(([5], shape), axis=0), @@ -203,15 +203,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): for substructure in structure ]) - @parameterized.parameters( - (None, np.int32([]), dtypes.bool), - (None, np.int32([]), dtypes.int32), - (None, np.int32([]), dtypes.float32), - (None, np.int32([]), dtypes.string), - (None, np.int32([2]), dtypes.int32), - (None, np.int32([2, 2]), dtypes.int32), - ((None, None, None), np.int32([]), dtypes.int32), - ((None, (None, None)), np.int32([]), dtypes.int32), + @parameterized.named_parameters( + ("1", None, np.int32([]), dtypes.bool), + ("2", None, np.int32([]), dtypes.int32), + ("3", None, np.int32([]), dtypes.float32), + ("4", None, np.int32([]), dtypes.string), + ("5", None, np.int32([2]), dtypes.int32), + ("6", None, 
np.int32([2, 2]), dtypes.int32), + ("7", (None, None, None), np.int32([]), dtypes.int32), + ("8", (None, (None, None)), np.int32([]), dtypes.int32), ) def testWindowDatasetBatchSparse(self, structure, shape, dtype): """Tests batching of sparse tensor windows. @@ -235,7 +235,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): structure, shape, dtype).repeat(5).apply( grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run( self._structuredSparseElement(structure, np.concatenate(([5], shape), axis=0), @@ -243,10 +243,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int32([]),), - (np.int32([1]),), - (np.int32([1, 2, 3]),), + @parameterized.named_parameters( + ("1", np.int32([])), + ("2", np.int32([1])), + ("3", np.int32([1, 2, 3])), ) def testWindowDatasetBatchSparseDynamicShape(self, shape): """Tests batching of dynamically shaped sparse tensor windows. 
@@ -263,7 +263,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shape_t: shape}) expected = sess.run( self._structuredSparseElement(None, @@ -284,17 +284,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): for substructure in structure ])) - @parameterized.parameters( - (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]), - (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), - (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]), - ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])), + @parameterized.named_parameters( + ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]), + ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]), + ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]), + ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), + ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]), + ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ("8", (None, + (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])), ) def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype, padded_shape): @@ -320,7 +321,7 @@ class 
WindowDatasetTest(test.TestCase, parameterized.TestCase): grouping.window_dataset(len(shapes))).apply( grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) expected = sess.run( self._structuredElement( @@ -329,10 +330,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int32([[1], [2], [3]]), [-1]), - (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]), - (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), + @parameterized.named_parameters( + ("1", np.int32([[1], [2], [3]]), [-1]), + ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]), + ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), ) def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape): """Tests padded batching of dynamically shaped dense tensor windows. @@ -351,7 +352,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shapes_t: shapes}) expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape) expected = sess.run( @@ -361,9 +362,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int32([[1]]), np.int32([0])), - (np.int32([[10], [20]]), np.int32([15])), + @parameterized.named_parameters( + ("1", np.int32([[1]]), np.int32([0])), + ("2", np.int32([[10], [20]]), np.int32([15])), ) def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape): """Tests invalid padded batching of dense tensor windows. 
@@ -379,7 +380,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): grouping._map_x_dataset( lambda x: batching.padded_batch_window(x, padded_shape))) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) @@ -420,17 +421,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): for substructure in structure ]) - @parameterized.parameters( - (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]), - (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), - (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]), - ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])), + @parameterized.named_parameters( + ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]), + ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]), + ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]), + ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), + ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]), + ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ("8", (None, + (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])), ) def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype, padded_shape): @@ -456,17 +458,17 @@ class 
WindowDatasetTest(test.TestCase, parameterized.TestCase): structure, shapes, dtype).apply(grouping.window_dataset( len(shapes))).apply(grouping._map_x_dataset(fn)) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: expected = sess.run( self._structuredRaggedSparseElement(structure, shapes, dtype, padded_shape)) actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int64([[1], [2], [3]]), [-1]), - (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]), - (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), + @parameterized.named_parameters( + ("1", np.int64([[1], [2], [3]]), [-1]), + ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]), + ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), ) def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes, padded_shape): @@ -487,7 +489,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(init_op, {shapes_t: shapes}) expected = sess.run( self._structuredRaggedSparseElement(None, shapes, dtypes.int32, @@ -495,9 +497,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int64([[1]]), [0]), - (np.int64([[10], [20]]), [15]), + @parameterized.named_parameters( + ("1", np.int64([[1]]), [0]), + ("2", np.int64([[10], [20]]), [15]), ) def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape): """Tests invalid padded batching of sparse tensor windows. 
@@ -514,7 +516,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): grouping._map_x_dataset( lambda x: batching.padded_batch_window(x, padded_shape))) get_next = dataset.make_one_shot_iterator().get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py index c603ecc5ab27a711557376246b093fd5f80f8aec..867ee2ba3794df77df64b3346138cdffb526abdc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py @@ -61,7 +61,7 @@ class TFRecordWriterTest(test.TestCase): return os.path.join(self.get_temp_dir(), "tf_record.out.txt") def testWrite(self): - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.writer, feed_dict={ self.filename: self._createFile(), @@ -71,7 +71,7 @@ class TFRecordWriterTest(test.TestCase): def testWriteZLIB(self): options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.writer, feed_dict={ @@ -84,7 +84,7 @@ class TFRecordWriterTest(test.TestCase): def testWriteGZIP(self): options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run( self.writer, feed_dict={ diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 9c2001c34f4129c2530f2e882768658ab7fe5819..367c159dc5db688b652f2e88a92e44186d7c8bfd 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -272,9 +272,9 @@ def _padded_batch_dense_window(dataset, padded_shape, padding_value=None): padding_value = 0 def 
batch_init_fn(_): - return array_ops.fill( - array_ops.concat([np.array([0], dtype=np.int32), padded_shape], 0), - constant_op.constant(padding_value, dtype=dataset.output_types)) + batch_shape = array_ops.concat( + [np.array([0], dtype=np.int32), padded_shape], 0) + return gen_array_ops.empty(batch_shape, dtype=dataset.output_types) def batch_reduce_fn(state, value): return array_ops.concat([state, [value]], 0) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 6edc1d79902c571b34b6a0a108c4d62cb6097ccb..099e10db921b78fc9fa3bcf73979ae6c33bc1972 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -124,7 +124,8 @@ def bucket_by_sequence_length(element_length_func, bucket_batch_sizes, padded_shapes=None, padding_values=None, - pad_to_bucket_boundary=False): + pad_to_bucket_boundary=False, + no_padding=False): """A transformation that buckets elements in a `Dataset` by length. Elements of the `Dataset` are grouped together by length and then are padded @@ -152,6 +153,8 @@ def bucket_by_sequence_length(element_length_func, unknown size to bucket boundary minus 1 (i.e., the maximum length in each bucket), and caller must ensure that the source `Dataset` does not contain any elements with length longer than `max(bucket_boundaries)`. + no_padding: `bool`, indicates whether to pad the batch features (features + need to be either of type `tf.SparseTensor` or of same shape). 
Returns: A `Dataset` transformation function, which can be passed to @@ -199,7 +202,9 @@ def bucket_by_sequence_length(element_length_func, def batching_fn(bucket_id, grouped_dataset): """Batch elements in dataset.""" - batch_size = batch_sizes[bucket_id] + batch_size = window_size_fn(bucket_id) + if no_padding: + return grouped_dataset.batch(batch_size) none_filler = None if pad_to_bucket_boundary: err_msg = ("When pad_to_bucket_boundary=True, elements must have " diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 38c0a09c33b373efe5bd798a62026602db1a7c71..92d4251a864dae7d5725b0f177b54c5cbcc14aec 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -220,6 +220,7 @@ def sample_from_datasets(datasets, weights=None, seed=None): if weights is None: # Select inputs with uniform probability. logits = [[1.0] * num_datasets] + else: # Use the given `weights` as the probability of choosing the respective # input. @@ -245,8 +246,11 @@ def sample_from_datasets(datasets, weights=None, seed=None): return array_ops.squeeze( stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) - selector_input = random_ops.RandomDataset(seed).batch(2).map( - select_dataset_constant_logits) + selector_input = dataset_ops.MapDataset( + random_ops.RandomDataset(seed).batch(2), + select_dataset_constant_logits, + use_inter_op_parallelism=False) + else: # Use each element of the given `weights` dataset as the probability of # choosing the respective input. 
@@ -259,9 +263,12 @@ def sample_from_datasets(datasets, weights=None, seed=None): return array_ops.squeeze( stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) - selector_input = dataset_ops.Dataset.zip( - (logits_ds, random_ops.RandomDataset(seed).batch(2) - )).map(select_dataset_varying_logits) + logits_and_seeds = dataset_ops.Dataset.zip( + (logits_ds, random_ops.RandomDataset(seed).batch(2))) + selector_input = dataset_ops.MapDataset( + logits_and_seeds, + select_dataset_varying_logits, + use_inter_op_parallelism=False) return _DirectedInterleaveDataset(selector_input, datasets) diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py index 54d5cd6da068fa5471b7beafcc66d76b5972e7d5..3d0d0993c9209109465b9f428e905bb471cfc738 100644 --- a/tensorflow/contrib/data/python/ops/map_defun.py +++ b/tensorflow/contrib/data/python/ops/map_defun.py @@ -53,6 +53,4 @@ def map_defun(fn, elems, output_dtypes, output_shapes): elems = [ops.convert_to_tensor(e) for e in elems] output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes] - if not all(s.is_fully_defined() for s in output_shapes): - raise ValueError("All fn output shapes must be fully defined.") return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn) diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py index fa1b851ad74bcf2cff69d42bce3eaa38822cd663..4114b62e29180932171fa222022961d632e4cabc 100644 --- a/tensorflow/contrib/data/python/ops/optimization.py +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -46,6 +46,21 @@ def assert_next(transformations): return _apply_fn +def model(): + """A transformation that models performance. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. 
+ """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return _ModelDataset(dataset) + + return _apply_fn + + def optimize(optimizations=None): """A transformation that applies optimizations. @@ -97,6 +112,32 @@ class _AssertNextDataset(dataset_ops.Dataset): return self._input_dataset.output_types +class _ModelDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and models performance.""" + + def __init__(self, input_dataset): + """See `optimize()` for details.""" + super(_ModelDataset, self).__init__() + self._input_dataset = input_dataset + + def _as_variant_tensor(self): + return gen_dataset_ops.model_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + **dataset_ops.flat_structure(self)) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + class _OptimizeDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and applies optimizations.""" diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 7f09ba71dc33389a198a96cfb292ef8904685f14..4c466781f7f659e8d7e267500a118d482d76da15 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -499,7 +499,8 @@ def make_csv_dataset( # indefinitely, and all batches will be full-sized. dataset = dataset.batch(batch_size=batch_size, drop_remainder=num_epochs is None) - dataset = dataset.map(map_fn) + dataset = dataset_ops.MapDataset( + dataset, map_fn, use_inter_op_parallelism=False) dataset = dataset.prefetch(prefetch_buffer_size) return dataset @@ -778,7 +779,8 @@ def make_batched_features_dataset(file_pattern, # Extract values if the `Example` tensors are stored as key-value tuples. 
if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset.map(lambda _, v: v) + dataset = dataset_ops.MapDataset( + dataset, lambda _, v: v, use_inter_op_parallelism=False) # Apply dataset repeat and shuffle transformations. dataset = _maybe_shuffle_and_repeat( diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 30e1992c015d35859218d1b7fe3b2f3eb7c09b9b..91a27f97b7f75511db4b377220a353787beca30e 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -76,7 +76,7 @@ We then compile the Keras model and pass the `MirroredStrategy` object in the ```python model.compile(loss='mean_squared_error', optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2), - distribute=strategy) + distribute=distribution) ``` To train the model we call Keras `fit` API using the input dataset that we diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index c524d8b394afa664acf88f3e54eb125b061b2217..87f76eaa948e09f2a3b7fd0ba52a154824e9fe33 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -708,19 +708,32 @@ cuda_py_test( ], ) -cuda_py_test( - name = "keras_test", +py_library( + name = "keras_test_lib", + testonly = 1, srcs = ["keras_test.py"], - additional_deps = [ - "//third_party/py/numpy", + deps = [ + ":combinations", "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:client_testlib", "//tensorflow/python:training", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/keras", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +cuda_py_test( + name = "keras_test", + srcs = ["keras_test.py"], + additional_deps = [ + ":keras_test_lib", ], tags = [ "multi_and_single_gpu", + "no_pip", "no_windows_gpu", "notsan", ], diff --git 
a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 4fa8aa06cce38e1be0bf0b87951127499fdcc44f..77079d0df9a94254384e75b98a0f6432189f05d8 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -229,6 +229,8 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): if not session_config or not self._cluster_spec: return + session_config.isolate_session_state = True + assert self._task_type assert self._task_id is not None diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 2301ba9233d29a1e5d054e71e4d9383af8bd48fd..244d1fcec8ba481337afeede181c29d0552e3c44 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -50,10 +50,12 @@ from tensorflow.contrib.cluster_resolver import TPUClusterResolver from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib +from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2 from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.training import adagrad from tensorflow.python.training import adam from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent @@ -328,6 +330,10 @@ tpu_strategy = NamedDistribution( "TPU", lambda: tpu_lib.TPUStrategy( TPUClusterResolver(""), steps_per_run=5), required_tpu=True) 
+tpu_strategy_one_step = NamedDistribution( + "TPU", lambda: tpu_lib.TPUStrategy( + TPUClusterResolver(""), steps_per_run=1), + required_tpu=True) # Note that we disable prefetching for testing since prefetching makes # the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( @@ -343,17 +349,23 @@ mirrored_strategy_with_two_gpus = NamedDistribution( adam_optimizer_v1_fn = NamedObject( - "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) + "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1)) gradient_descent_optimizer_v1_fn = NamedObject( "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2)) -optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn] +adagrad_optimizer_v1_fn = NamedObject( + "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001)) +optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn, + adagrad_optimizer_v1_fn] adam_optimizer_v2_fn = NamedObject( - "AdamV2", lambda: adam_v2.AdamOptimizer(0.2, epsilon=1)) + "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1)) gradient_descent_optimizer_v2_fn = NamedObject( "GradientDescentV2", lambda: gradient_descent_v2.GradientDescentOptimizer(0.2)) -optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn] +adagrad_optimizer_v2_fn = NamedObject( + "AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001)) +optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn, + adagrad_optimizer_v2_fn] graph_and_eager_modes = ["graph", "eager"] diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index a20069c4fe4713897ba9543cd56615db7a2fc3cb..a84ef041960e389c08246fc8a16df2300856d968 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -58,13 +58,12 @@ def get_input_datasets(): train_ds = tf.data.Dataset.from_tensor_slices((x_train, 
y_train)) train_ds = train_ds.repeat() train_ds = train_ds.shuffle(100) - train_ds = train_ds.batch(64) + train_ds = train_ds.batch(64, drop_remainder=True) # eval dataset eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)) eval_ds = eval_ds.repeat() - eval_ds = eval_ds.shuffle(100) - eval_ds = eval_ds.batch(64) + eval_ds = eval_ds.batch(64, drop_remainder=True) return train_ds, eval_ds, input_shape diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index d39fd57294a67a4a98a528f2aa99f0436f245847..5f35e381899a03f12cf0a6ed0168b9e500d41801 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -18,9 +18,12 @@ from __future__ import division from __future__ import print_function import os +from absl.testing import parameterized import numpy as np +from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.contrib.distribute.python import values from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops @@ -31,6 +34,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import distributed_training_utils +from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -63,6 +67,32 @@ def simple_functional_model(): return model +def multi_inputs_multi_outputs_model(): + input_a = keras.layers.Input(shape=(16,), name='input_a') + input_b = keras.layers.Input(shape=(16,), name='input_b') + input_m = keras.layers.Input(shape=(8,), dtype='string', name='input_m') + dense = 
keras.layers.Dense(8, name='dense_1') + + interm_a = dense(input_a) + # Read m + interm_m = keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m) + interm_s = keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a]) + interm_b = dense(input_b) + merged = keras.layers.concatenate([interm_s, interm_b], name='merge') + output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged) + output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged) + model = keras.models.Model( + inputs=[input_a, input_b, input_m], outputs=[output_c, output_d]) + model.compile( + loss='categorical_crossentropy', + optimizer=gradient_descent.GradientDescentOptimizer(0.001), + metrics={ + 'dense_2': 'categorical_accuracy', + 'dense_3': 'categorical_accuracy' + }) + return model + + def get_ds_train_input_fn(): np.random.seed(_RANDOM_SEED) (x_train, y_train), _ = testing_utils.get_test_data( @@ -91,6 +121,68 @@ def get_ds_test_input_fn(): return dataset +def get_multi_inputs_multi_outputs_data(): + (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(16,), + num_classes=3, + random_seed=_RANDOM_SEED) + (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(16,), + num_classes=2, + random_seed=_RANDOM_SEED) + (m_train, _), (m_test, _) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(8,), + num_classes=2, + random_seed=_RANDOM_SEED) + + c_train = keras.utils.to_categorical(c_train) + c_test = keras.utils.to_categorical(c_test) + d_train = keras.utils.to_categorical(d_train) + d_test = keras.utils.to_categorical(d_test) + + train_data = { + 'input_a': a_train, + 'input_b': b_train, + 'input_m': m_train, + 'output_c': c_train, + 'output_d': d_train + } + test_data = { + 'input_a': a_test, + 'input_b': b_test, + 'input_m': m_test, + 'output_c': c_test, 
+ 'output_d': d_test + } + + return (train_data, test_data) + + +def batch_wrapper(dataset, batch_size, distribution): + # TPUs currently require fully defined input shapes, drop_remainder ensures + # the input will have fully defined shapes. + if isinstance(distribution, tpu_strategy.TPUStrategy): + return dataset.batch(batch_size, drop_remainder=True) + else: + return dataset.batch(batch_size) + + +def all_combinations(): + return combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.tpu_strategy_one_step], + mode=['graph']) + + class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): def setUp(self): @@ -99,6 +191,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): gfile.MakeDirs(self._base_dir) self._config = run_config_lib.RunConfig( tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) + self._dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) def tearDown(self): writer_cache.FileWriterCache.clear() @@ -152,6 +246,53 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) + def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self): + train_data, test_data = get_multi_inputs_multi_outputs_data() + + def train_input_fn(): + input_dict = { + 'input_a': train_data['input_a'], + 'input_b': train_data['input_b'], + 'input_m': train_data['input_m'].astype(np.str) + } + output_dict = { + 'dense_2': train_data['output_c'], + 'dense_3': train_data['output_d'] + } + return dataset_ops.Dataset.from_tensor_slices((input_dict, + output_dict)).batch(16) + + def eval_input_fn(): + input_dict = { + 'input_a': test_data['input_a'], + 'input_b': test_data['input_b'], + 'input_m': test_data['input_m'].astype(np.str) + } + output_dict 
= { + 'dense_2': test_data['output_c'], + 'dense_3': test_data['output_d'] + } + return dataset_ops.Dataset.from_tensor_slices((input_dict, + output_dict)).batch(16) + + self.do_test_multi_inputs_multi_outputs_with_input_fn( + train_input_fn, eval_input_fn) + + def do_test_multi_inputs_multi_outputs_with_input_fn(self, train_input_fn, + eval_input_fn): + config = run_config_lib.RunConfig( + tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=self._dist) + with self.cached_session(): + model = multi_inputs_multi_outputs_model() + est_keras = keras_lib.model_to_estimator(keras_model=model, config=config) + baseline_eval_results = est_keras.evaluate( + input_fn=eval_input_fn, steps=1) + est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16) + eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) + self.assertLess(eval_results['loss'], baseline_eval_results['loss']) + def test_keras_optimizer_with_distribution_strategy(self): dist = mirrored_strategy.MirroredStrategy( devices=['/device:GPU:0', '/device:GPU:1']) @@ -175,7 +316,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): gfile.DeleteRecursively(self._config.model_dir) -class TestWithDistributionStrategy(test.TestCase): +class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): def test_validating_dataset_input_tensors_with_shape_mismatch(self): with self.cached_session(): @@ -215,7 +356,7 @@ class TestWithDistributionStrategy(test.TestCase): distributed_training_utils.validate_distributed_dataset_inputs( strategy, x, y) - def test_calling_model_on_same_dataset(self): + def test_calling_model_with_numpy_arrays(self): with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) @@ -228,11 +369,44 @@ class TestWithDistributionStrategy(test.TestCase): '/device:GPU:0']) model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + inputs = np.zeros((64, 3), 
dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, + validation_data=(inputs, targets)) + + # TODO(anjalisridhar): We need tests for when the batch size and steps are + # smaller and results in a 0 batch_size and steps value. + model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) + + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) + + @combinations.generate(all_combinations()) + def test_calling_model_on_same_dataset(self, distribution): + with self.cached_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + inputs = np.zeros((10, 3), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - dataset = dataset.batch(10) + dataset = batch_wrapper(dataset, 10, distribution) # Call fit with validation data model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, @@ -241,6 +415,9 @@ class TestWithDistributionStrategy(test.TestCase): validation_data=dataset, validation_steps=2) model.predict(dataset, steps=2) + # TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work + # as clone_model's input_tensors argument only seems to accept list and not + # tuples or dict. 
def test_fit_with_tuple_and_dict_dataset_inputs(self): with self.cached_session(): a = keras.layers.Input(shape=(3,), name='input_a') @@ -282,7 +459,8 @@ class TestWithDistributionStrategy(test.TestCase): model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) - def test_fit_eval_and_predict_methods_on_dataset(self): + @combinations.generate(all_combinations()) + def test_fit_eval_and_predict_methods_on_dataset(self, distribution): with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) @@ -291,16 +469,13 @@ class TestWithDistributionStrategy(test.TestCase): optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', - '/device:CPU:0']) - - model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + model.compile(optimizer, loss, metrics=metrics, distribute=distribution) inputs = np.zeros((10, 3), dtype=np.float32) targets = np.zeros((10, 4), dtype=np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - dataset = dataset.batch(10) + dataset = batch_wrapper(dataset, 10, distribution) model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) model.evaluate(dataset, steps=2, verbose=1) @@ -446,8 +621,7 @@ class TestWithDistributionStrategy(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - with self.assertRaisesRegexp(ValueError, - 'expected input to have 2 dimensions'): + with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) # Wrong input shape @@ -497,6 +671,8 @@ class TestWithDistributionStrategy(test.TestCase): class LossMaskingWithDistributionStrategyTest(test.TestCase): + # TODO(priyag): Enable all strategies for this test. 
Currently it does not + # work for TPU due to some invalid datatype. def test_masking(self): with self.cached_session(): np.random.seed(1337) @@ -520,24 +696,25 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase): self.assertEqual(hist.history['loss'][0], 0) -class NormalizationLayerWithDistributionStrategyTest(test.TestCase): +class NormalizationLayerWithDistributionStrategyTest( + test.TestCase, parameterized.TestCase): - def test_batchnorm_correctness(self): + @combinations.generate(all_combinations()) + def test_batchnorm_correctness(self, distribution): with self.cached_session(): model = keras.models.Sequential() norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) model.add(norm) - strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0', - '/device:GPU:0']) model.compile(loss='mse', optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=strategy) + distribute=distribution) # centered on 5.0, variance 10.0 x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) + x = x.astype('float32') dataset = dataset_ops.Dataset.from_tensor_slices((x, x)) dataset = dataset.repeat(100) - dataset = dataset.batch(32) + dataset = batch_wrapper(dataset, 32, distribution) model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10) out = model.predict(dataset, steps=2) @@ -547,9 +724,11 @@ class NormalizationLayerWithDistributionStrategyTest(test.TestCase): np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) -class CorrectnessWithDistributionStrategyTest(test.TestCase): +class CorrectnessWithDistributionStrategyTest(test.TestCase, + parameterized.TestCase): - def test_correctness(self): + @combinations.generate(all_combinations()) + def test_correctness(self, distribution): with self.cached_session(): keras.backend.set_image_data_format('channels_last') num_samples = 10000 @@ -558,43 +737,43 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase): x_train = x_train.astype('float32') y_train = 
y_train.astype('float32') - model = keras.Sequential() - model.add(keras.layers.Dense(1, input_shape=(1,))) - - # With DistributionStrategy - dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) - dataset_with = dataset_with.batch(32) - strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0', - '/device:GPU:0']) - - model.compile(loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5), - distribute=strategy) - model.fit(x=dataset_with, epochs=1, steps_per_epoch=310) - wts_with_ds = model.get_weights() - - x_predict = [[1], [2], [3], [4]] - predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict, - x_predict)) - predict_dataset_with = predict_dataset_with.batch(2) - predict_with_ds = model.predict(predict_dataset_with, steps=1) - predict_with_ds = np.reshape(predict_with_ds, (4, 1)) - - # Without DistributionStrategy - dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train, + def fit_and_predict(with_distribution=None): + model = keras.Sequential() + model.add(keras.layers.Dense(1, input_shape=(1,))) + model.compile( + loss=keras.losses.mean_squared_error, + optimizer=gradient_descent.GradientDescentOptimizer(0.5), + distribute=with_distribution) + + batch_size = 64 + if with_distribution: + batch_size //= with_distribution.num_towers + train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) - dataset_without = dataset_without.batch(64) - - model.compile(loss=keras.losses.mean_squared_error, - optimizer=gradient_descent.GradientDescentOptimizer(0.5)) - model.fit(x=dataset_without, epochs=1, steps_per_epoch=310) - wts_without_ds = model.get_weights() - - x_predict = [[1], [2], [3], [4]] - predict_dataset_without = dataset_ops.Dataset.from_tensor_slices(( - x_predict, x_predict)) - predict_dataset_without = predict_dataset_without.batch(4) - predict_without_ds = model.predict(predict_dataset_without, steps=1) + train_dataset = 
batch_wrapper(train_dataset, batch_size, distribution) + # Running only 100 steps instead of the full dataset to keep test + # duration small. + model.fit(x=train_dataset, epochs=1, steps_per_epoch=100) + + weights = model.get_weights() + + x_predict = [[1.], [2.], [3.], [4.]] + predict_batch_size = 4 + if with_distribution: + predict_batch_size //= with_distribution.num_towers + predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict, + x_predict)) + predict_dataset = batch_wrapper(predict_dataset, + predict_batch_size, distribution) + predict_result = model.predict(predict_dataset, steps=1) + predict_result = np.reshape(predict_result, (4, 1)) + + return weights, predict_result + + wts_with_ds, predict_with_ds = fit_and_predict( + with_distribution=distribution) + wts_without_ds, predict_without_ds = fit_and_predict( + with_distribution=None) # Verify that the weights are the same within some limits of tolerance. np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3) @@ -603,5 +782,8 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase): np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3) +# TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1. 
+ + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index bdac4fb58c2ca8c4f6a322a6f477a9e3657b8f93..ba147e78241e5ab45809e498e00debd45a2c49b4 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -183,6 +183,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): "dense/kernel", "dense/bias", "beta1_power", "beta2_power", "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam", "dense/bias/Adam_1" + ], + "Adagrad": [ + "dense/kernel/Adagrad", "dense/kernel", + "dense/bias/Adagrad", "dense/bias" ] } variables = variables_map[optimizer_fn().get_name()] diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index d1235b7afb31b29cb101b2d900ae703515ead650..0c6805d68218029abcad784b476b76bf3d368a9f 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -572,6 +572,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): task_type=None, task_id=None): del task_type, task_id + + if session_config: + session_config.isolate_session_state = True + if cluster_spec: self._initialize_multi_worker(self._num_gpus, cluster_spec) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 88d7768b1447bd58e2c6349a2302f151dd34527d..1125d027f64420863386d4fbd9db5564a5847825 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -412,6 +412,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): if not session_config or not self._cluster_spec: return + session_config.isolate_session_state = 
False + assert self._cluster_spec assert self._task_type assert self._task_id is not None diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py index bb10b546a1907bba26cd0d7e7c5308420adbaf3f..16799104e8112f4391152c0cf2a15af81f8c2c9d 100644 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py @@ -55,14 +55,14 @@ class PrefetchingOpsV2Test(test.TestCase): next_element = iterator.get_next() output = [] + # TODO(rohanj): Modify test to go till the end of the dataset when we + # switch to MultiDeviceIterator. with self.cached_session() as sess: - for _ in range(5): + for _ in range(4): result = sess.run(next_element) self.assertEqual(2, len(result)) output.extend(result) - self.assertEquals(set(range(10)), set(output)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + self.assertEquals(set(range(8)), set(output)) def testPrefetchToTwoDevicesWithReinit(self): if not test_util.is_gpu_available(): @@ -75,14 +75,14 @@ class PrefetchingOpsV2Test(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() + # TODO(rohanj): Modify test to go till the end of the dataset when we + # switch to MultiDeviceIterator. 
with self.cached_session() as sess: sess.run(iterator.initializer) - for _ in range(5): - sess.run(next_element) - with self.assertRaises(errors.OutOfRangeError): + for _ in range(4): sess.run(next_element) sess.run(iterator.initializer) - for _ in range(5): + for _ in range(4): sess.run(next_element) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 32d7444e42cd2e12c0f41c4e53c54e3fae0dfa0a..6ba83976fcd47fe1680992fbbd5bb56ffa68071d 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -310,4 +310,18 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def get_host_cpu_device(self, host_id): if self._tpu_cluster_resolver.get_master() in ('', 'local'): return '/replica:0/task:0/device:CPU:0' - return '/job:tpu_worker/task:%d/device:CPU:0' % (host_id,) + job_name = self._tpu_cluster_resolver.get_job_name() or 'tpu_worker' + return '/job:%s/task:%d/device:CPU:0' % (job_name, host_id) + + def configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): + del cluster_spec, task_type, task_id + if session_config: + session_config.isolate_session_state = True + cluster_spec = self._tpu_cluster_resolver.cluster_spec() + if cluster_spec: + session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 97c53ae2b94988ad9938c9d1cf3326e4076e8d6f..9aadc634da5a7591747a4f651cdb45376393402d 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -166,6 +166,7 @@ cuda_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], + tags = ["notap"], ) cuda_py_test( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py 
b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py index a7bd51430e384c199ca8abd06ef9887e998cc380..1e36b7ff9be4018f6b80a89e5967e5e21e9bd275 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator +from tensorflow.python.ops.linalg import linalg from tensorflow.python.platform import test diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index 196cc413353657c2dfadd3a1c87b97518c6f235b..13370497ce706a60b1d0c7f4f148076b354626a7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -22,7 +22,6 @@ import numpy as np from scipy import stats from tensorflow.contrib import distributions -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -30,6 +29,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg from tensorflow.python.platform import test bs = bijectors diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index 
25f29452c3949600b8a4153a8585dd7269bd3b2b..ba31697c589006c9fbee2fe68639e5f1daf51f62 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape from tensorflow.python.framework import dtypes @@ -29,6 +28,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.linalg import linalg from tensorflow.python.util import deprecation diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 6959b3e8775d2dd488b4ee3252d143ef376d58f9..b4ad33cf6dbf073419a27f378c8eefdba97c5af7 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond @@ -27,6 +26,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg from tensorflow.python.ops.distributions import distribution as distribution_lib # The following two lines are redundant, in a sense. 
The first enables diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index d8401801f21afbe8fd042053c6a38a31a2539438..74d9d04fc702a90a5fc5a31f554abe257dd2860d 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops +from tensorflow.python.ops.linalg import linalg from tensorflow.python.util import deprecation diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index d9110947ecdbba1a63669573f46db17b02e512ab..c6a23e4336fffbf7b61490dd3468bc71c7f421cc 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -18,10 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.ops.linalg import linalg from tensorflow.python.util import deprecation diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index f1accaaa4c920344608015c792a2c3606de1337f..49b9de0ab508f5db090bb1349f596da1b2a71b49 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ 
b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -21,7 +21,6 @@ from __future__ import print_function import math import numpy as np -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util from tensorflow.python.framework import constant_op @@ -36,6 +35,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.linalg import linalg from tensorflow.python.util import deprecation __all__ = [ diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py index 7d2274db9b051e604266074651f4cbd331f20f48..48d093e0754f79725f3e3e900320773aae41e8ad 100644 --- a/tensorflow/contrib/eager/python/evaluator_test.py +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -117,7 +117,7 @@ class EvaluatorTest(test.TestCase): self.assertEqual(6.0, results["mean"].numpy()) def testDatasetGraph(self): - with context.graph_mode(), ops.Graph().as_default(), self.test_session(): + with context.graph_mode(), ops.Graph().as_default(), self.cached_session(): e = SimpleEvaluator(IdentityModel()) ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) init_op, call_op, results_op = e.evaluate_on_dataset(ds) @@ -126,7 +126,7 @@ class EvaluatorTest(test.TestCase): self.assertEqual(6.0, results["mean"]) def testWriteSummariesGraph(self): - with context.graph_mode(), ops.Graph().as_default(), self.test_session(): + with context.graph_mode(), ops.Graph().as_default(), self.cached_session(): e = SimpleEvaluator(IdentityModel()) ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) training_util.get_or_create_global_step() diff --git 
a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 529c99b37c7c37e70afe0d95ccca15200afce60b..3acecd283cda83992bab0c37cf0b8037ed2cf27a 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1056,7 +1056,7 @@ "\n", " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", "\n", - " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", " result.append(index_word[predicted_id])\n", "\n", " if index_word[predicted_id] == '':\n", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index 40bc09872482c6062a870a3c274ba792ab83f3de..e0d5e494d432b365b0d1dcff6b634de2e6213a43 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -610,7 +610,7 @@ "\n", " # using a multinomial distribution to predict the word returned by the model\n", " predictions = predictions / temperature\n", - " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", " \n", " # We pass the predicted word as the next input to the model\n", " # along with the previous hidden state\n", diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index f1e1f99c57a77a6c6d3cb0578e1f1c776933605d..560fc8c5a22a0e7acf1f37cf7daf7790dc14de19 100644 --- 
a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -677,7 +677,7 @@ " attention_weights = tf.reshape(attention_weights, (-1, ))\n", " attention_plot[t] = attention_weights.numpy()\n", "\n", - " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", "\n", " result += targ_lang.idx2word[predicted_id] + ' '\n", "\n", diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md index fabd7b3e206d3a1954893a2b75361146d4709d00..750bbc66f3555a5d30ac1fd81d87ff54f7389f64 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md @@ -23,4 +23,4 @@ Attribution-ShareAlike License and is available at https://en.wikipedia.org/wiki/List_of_colors:_N-Z This example was adapted from - https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot + https://github.com/random-forests/tensorflow-workshop/tree/master/archive/extras/colorbot diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD deleted file mode 100644 index 638c57d1c92c1dce0ef9e73e9a6ac2369358080b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/scan/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -cuda_py_test( - name = "scan_test", - size = "small", - srcs = ["scan_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow:tensorflow_py", - ], -) - -cuda_py_test( - name = "scan_graph_test", - size = "small", - srcs = ["scan_graph_test.py"], - additional_deps = [ - 
"//third_party/py/numpy", - "//tensorflow:tensorflow_py", - ], -) diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py deleted file mode 100644 index d4b8c8941ec411912f3089315d038fc4bcd049ae..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Unit test for tf.scan under graph mode execution.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import time - -import numpy as np -import tensorflow as tf - - -class ScanBenchmark(tf.test.Benchmark): - - def runScan(self, n): - elems = np.arange(n) - start_time = time.time() - sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) - with tf.Session() as sess: - sess.run(sum_op) - wall_time = time.time() - start_time - - self.report_benchmark( - name='scan', - iters=n, - wall_time=wall_time) - - def benchmarkScan16000(self): - self.runScan(16000) - - def benchmarkScan32000(self): - self.runScan(32000) - - def benchmarkScan64000(self): - self.runScan(64000) - - def benchmarkScan128000(self): - self.runScan(128000) - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py deleted file mode 100644 index a02fc24c79dae6c2565db8b138b1d7391d169ed8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/scan/scan_test.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Unit test for tf.scan under eager execution.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import time - -import numpy as np -import tensorflow as tf - - -class ScanBenchmark(tf.test.Benchmark): - - def runScan(self, n): - elems = np.arange(n) - start_time = time.time() - _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) - wall_time = time.time() - start_time - - self.report_benchmark( - name='scan', - iters=n, - wall_time=wall_time) - - def benchmarkScan16000(self): - self.runScan(16000) - - def benchmarkScan32000(self): - self.runScan(32000) - - def benchmarkScan64000(self): - self.runScan(64000) - - def benchmarkScan128000(self): - self.runScan(128000) - - -if __name__ == '__main__': - tf.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index dcc7b71d79f207019cec4425eb000b92420b9ca7..9d2d172752c7f3f3ee6eaa11ab8952313a4a3543 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -216,7 +216,7 @@ class MetricsTest(test.TestCase): self.assertEqual(m1.numer.name, "has_space/numer:0") def testGraphWithPlaceholder(self): - with context.graph_mode(), self.test_session() as sess: + with context.graph_mode(), self.cached_session() as sess: m = metrics.Mean() p = array_ops.placeholder(dtypes.float32) accumulate = m(p) @@ -309,7 +309,7 @@ class MetricsTest(test.TestCase): self.assertTrue(old_numer is m.numer) def testMetricsChain(self): - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): m1 = metrics.Mean() m2 = metrics.Mean(name="m2") update_m2 = m2(3.0) diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 
77f62df99d5a052e2df61d3f225e1860d4d1da72..6db311d52de61359995087fb5ca3d5461f74c4c1 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -18,6 +18,7 @@ py_library( ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":dnn_with_layer_annotations", ":early_stopping", ":export", ":exporter", @@ -126,6 +127,61 @@ py_test( ], ) +py_library( + name = "dnn_with_layer_annotations", + srcs = ["python/estimator/dnn_with_layer_annotations.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:nn", + "//tensorflow/python:partitioned_variables", + "//tensorflow/python:summary", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//tensorflow/python/saved_model:utils", + ], +) + +py_test( + name = "dnn_with_layer_annotations_test", + size = "medium", + srcs = ["python/estimator/dnn_with_layer_annotations_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", # b/67510291 + ], + deps = [ + ":dnn_with_layer_annotations", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python/estimator:dnn", + "//tensorflow/python/estimator:dnn_testing_utils", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:pandas_io", + "//tensorflow/python/estimator:prediction_keys", + 
"//tensorflow/python/feature_column", + "@six_archive//:six", + ], +) + py_library( name = "dnn_linear_combined", srcs = ["python/estimator/dnn_linear_combined.py"], @@ -446,6 +502,7 @@ py_library( "//tensorflow/python/estimator", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/ops/losses", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 258860f26340a0934e854f2d1950ead60e413234..78914ecacaf79fd25b33d4159601ab49d2b74c96 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.contrib.estimator.python.estimator.baseline import * from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * +from tensorflow.contrib.estimator.python.estimator.dnn_with_layer_annotations import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * from tensorflow.contrib.estimator.python.estimator.early_stopping import * from tensorflow.contrib.estimator.python.estimator.export import * @@ -76,6 +77,8 @@ _allowed_symbols = [ 'build_raw_supervised_input_receiver_fn', 'build_supervised_input_receiver_fn_from_input_fn', 'SavedModelEstimator' + 'DNNClassifierWithLayerAnnotations', + 'DNNRegressorWithLayerAnnotations', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py new file mode 100644 index 0000000000000000000000000000000000000000..152431d1b205845945cc2c079b747f81d739026f --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py @@ -0,0 +1,434 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Deep Neural Network estimators with layer annotations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import pickle + +from google.protobuf.any_pb2 import Any + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator.canned import dnn +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import nn +from tensorflow.python.ops.losses import losses +from tensorflow.python.saved_model import utils as saved_model_utils + + +class LayerAnnotationsCollectionNames(object): + """Names for the collections containing the annotations.""" + + UNPROCESSED_FEATURES = 'layer_annotations/unprocessed_features' + PROCESSED_FEATURES = 'layer_annotatons/processed_features' + FEATURE_COLUMNS = 'layer_annotations/feature_columns' + + @classmethod + def keys(cls, collection_name): + return '%s/keys' % collection_name + + @classmethod + def values(cls, collection_name): + return '%s/values' % collection_name + + +def serialize_feature_column(feature_column): + if isinstance(feature_column, feature_column_lib._EmbeddingColumn): # pylint: disable=protected-access + # We can't pickle nested 
functions, and we don't need the value of + # layer_creator in most cases anyway, so just discard its value. + args = feature_column._asdict() + args['layer_creator'] = None + temp = type(feature_column)(**args) + return pickle.dumps(temp) + return pickle.dumps(feature_column) + + +def _to_any_wrapped_tensor_info(tensor): + """Converts a `Tensor` to a `TensorInfo` wrapped in a proto `Any`.""" + any_buf = Any() + tensor_info = saved_model_utils.build_tensor_info(tensor) + any_buf.Pack(tensor_info) + return any_buf + + +def make_input_layer_with_layer_annotations(original_input_layer, mode): + """Make an input_layer replacement function that adds layer annotations.""" + + def input_layer_with_layer_annotations(features, + feature_columns, + weight_collections=None, + trainable=True, + cols_to_vars=None, + cols_to_output_tensors=None): + """Returns a dense `Tensor` as input layer based on given `feature_columns`. + + Generally a single example in training data is described with + FeatureColumns. + At the first layer of the model, this column oriented data should be + converted + to a single `Tensor`. + + This is like tf.feature_column.input_layer, except with added + Integrated-Gradient annotations. + + Args: + features: A mapping from key to tensors. `_FeatureColumn`s look up via + these keys. For example `numeric_column('price')` will look at 'price' + key in this dict. Values can be a `SparseTensor` or a `Tensor` depends + on corresponding `_FeatureColumn`. + feature_columns: An iterable containing the FeatureColumns to use as + inputs to your model. All items should be instances of classes derived + from `_DenseColumn` such as `numeric_column`, `embedding_column`, + `bucketized_column`, `indicator_column`. If you have categorical + features, you can wrap them with an `embedding_column` or + `indicator_column`. + weight_collections: A list of collection names to which the Variable will + be added. 
Note that variables will also be added to collections + `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`. + trainable: If `True` also add the variable to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + cols_to_vars: If not `None`, must be a dictionary that will be filled with + a mapping from `_FeatureColumn` to list of `Variable`s. For example, + after the call, we might have cols_to_vars = {_EmbeddingColumn( + categorical_column=_HashedCategoricalColumn( key='sparse_feature', + hash_bucket_size=5, dtype=tf.string), dimension=10): [